Commit 2ce12046 authored by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 443199105
parent b4d128f1
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
"""Contains common building blocks for neural networks.""" """Contains common building blocks for neural networks."""
import enum
from typing import Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union from typing import Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union
import tensorflow as tf import tensorflow as tf
...@@ -31,6 +32,14 @@ States = Dict[str, tf.Tensor] ...@@ -31,6 +32,14 @@ States = Dict[str, tf.Tensor]
Activation = Union[str, Callable] Activation = Union[str, Callable]
# String constants.
@enum.unique
class FeatureFusion(str, enum.Enum):
  """Names of the feature fusion methods accepted by segmentation heads.

  Every member also subclasses `str`, so it compares equal to (and hashes
  like) its plain string value. This lets a `feature_fusion` setting that is
  stored as an ordinary string, e.g. `'deeplabv3plus'`, be tested against
  these members with `==` or `in` without any conversion.

  `enum.unique` makes the class definition fail if two members are ever given
  the same string value, which would otherwise silently create an alias.
  """
  # Multiscale features are resized and fused at the target level.
  PYRAMID_FUSION = 'pyramid_fusion'
  # Fusion over decoder levels bounded by `decoder_min_level`.
  PANOPTIC_FPN_FUSION = 'panoptic_fpn_fusion'
  # Decoder features are fused with low level backbone features by
  # concatenation.
  DEEPLABV3PLUS = 'deeplabv3plus'
  # Same as DEEPLABV3PLUS, but the two feature maps are added instead of
  # concatenated.
  DEEPLABV3PLUS_SUM_TO_MERGE = 'deeplabv3plus_sum_to_merge'
@tf.keras.utils.register_keras_serializable(package='Vision') @tf.keras.utils.register_keras_serializable(package='Vision')
class SqueezeExcitationQuantized( class SqueezeExcitationQuantized(
helper.LayerQuantizerHelper, helper.LayerQuantizerHelper,
...@@ -237,10 +246,11 @@ class SegmentationHeadQuantized(tf.keras.layers.Layer): ...@@ -237,10 +246,11 @@ class SegmentationHeadQuantized(tf.keras.layers.Layer):
prediction layer. prediction layer.
upsample_factor: An `int` number to specify the upsampling factor to upsample_factor: An `int` number to specify the upsampling factor to
generate finer mask. Default 1 means no upsampling is applied. generate finer mask. Default 1 means no upsampling is applied.
feature_fusion: One of `deeplabv3plus`, `pyramid_fusion`, or None. If feature_fusion: One of `deeplabv3plus`, `deeplabv3plus_sum_to_merge`,
`deeplabv3plus`, features from decoder_features[level] will be fused `pyramid_fusion`, or None. If `deeplabv3plus`, features from
with low level feature maps from backbone. If `pyramid_fusion`, decoder_features[level] will be fused with low level feature maps from
multiscale features will be resized and fused at the target level. backbone. If `pyramid_fusion`, multiscale features will be resized and
fused at the target level.
decoder_min_level: An `int` of minimum level from decoder to use in decoder_min_level: An `int` of minimum level from decoder to use in
feature fusion. It is only used when feature_fusion is set to feature fusion. It is only used when feature_fusion is set to
`panoptic_fpn_fusion`. `panoptic_fpn_fusion`.
...@@ -327,7 +337,9 @@ class SegmentationHeadQuantized(tf.keras.layers.Layer): ...@@ -327,7 +337,9 @@ class SegmentationHeadQuantized(tf.keras.layers.Layer):
'epsilon': self._config_dict['norm_epsilon'], 'epsilon': self._config_dict['norm_epsilon'],
} }
if self._config_dict['feature_fusion'] == 'deeplabv3plus': if self._config_dict['feature_fusion'] in [
FeatureFusion.DEEPLABV3PLUS, FeatureFusion.DEEPLABV3PLUS_SUM_TO_MERGE
]:
# Deeplabv3+ feature fusion layers. # Deeplabv3+ feature fusion layers.
self._dlv3p_conv = helper.Conv2DQuantized( self._dlv3p_conv = helper.Conv2DQuantized(
kernel_size=1, kernel_size=1,
...@@ -388,6 +400,7 @@ class SegmentationHeadQuantized(tf.keras.layers.Layer): ...@@ -388,6 +400,7 @@ class SegmentationHeadQuantized(tf.keras.layers.Layer):
backbone_shape[1], backbone_shape[2], interpolation='bilinear') backbone_shape[1], backbone_shape[2], interpolation='bilinear')
self._concat_layer = helper.ConcatenateQuantized(axis=self._bn_axis) self._concat_layer = helper.ConcatenateQuantized(axis=self._bn_axis)
self._add_layer = tf.keras.layers.Add()
super().build(input_shape) super().build(input_shape)
...@@ -412,14 +425,16 @@ class SegmentationHeadQuantized(tf.keras.layers.Layer): ...@@ -412,14 +425,16 @@ class SegmentationHeadQuantized(tf.keras.layers.Layer):
segmentation prediction mask: A `tf.Tensor` of the segmentation mask segmentation prediction mask: A `tf.Tensor` of the segmentation mask
scores predicted from input features. scores predicted from input features.
""" """
if self._config_dict['feature_fusion'] in ('pyramid_fusion', if self._config_dict['feature_fusion'] in (
'panoptic_fpn_fusion'): FeatureFusion.PYRAMID_FUSION, FeatureFusion.PANOPTIC_FPN_FUSION):
raise ValueError( raise ValueError(
'The feature fusion method `pyramid_fusion` is not supported in QAT.') 'The feature fusion method `pyramid_fusion` is not supported in QAT.')
backbone_output = inputs[0] backbone_output = inputs[0]
decoder_output = inputs[1] decoder_output = inputs[1]
if self._config_dict['feature_fusion'] == 'deeplabv3plus': if self._config_dict['feature_fusion'] in {
FeatureFusion.DEEPLABV3PLUS, FeatureFusion.DEEPLABV3PLUS_SUM_TO_MERGE
}:
# deeplabv3+ feature fusion. # deeplabv3+ feature fusion.
x = decoder_output[str(self._config_dict['level'])] if isinstance( x = decoder_output[str(self._config_dict['level'])] if isinstance(
decoder_output, dict) else decoder_output decoder_output, dict) else decoder_output
...@@ -429,7 +444,10 @@ class SegmentationHeadQuantized(tf.keras.layers.Layer): ...@@ -429,7 +444,10 @@ class SegmentationHeadQuantized(tf.keras.layers.Layer):
y = self._activation_layer(y) y = self._activation_layer(y)
x = self._resizing_layer(x) x = self._resizing_layer(x)
x = tf.cast(x, dtype=y.dtype) x = tf.cast(x, dtype=y.dtype)
x = self._concat_layer([x, y]) if self._config_dict['feature_fusion'] == FeatureFusion.DEEPLABV3PLUS:
x = self._concat_layer([x, y])
else:
x = self._add_layer([x, y])
else: else:
x = decoder_output[str(self._config_dict['level'])] if isinstance( x = decoder_output[str(self._config_dict['level'])] if isinstance(
decoder_output, dict) else decoder_output decoder_output, dict) else decoder_output
......
...@@ -24,12 +24,15 @@ from official.projects.qat.vision.modeling.layers import nn_layers ...@@ -24,12 +24,15 @@ from official.projects.qat.vision.modeling.layers import nn_layers
class NNLayersTest(parameterized.TestCase, tf.test.TestCase): class NNLayersTest(parameterized.TestCase, tf.test.TestCase):
@parameterized.parameters( @parameterized.parameters(
('deeplabv3plus', 1), ('deeplabv3plus', 1, 128, 128),
('deeplabv3plus', 2), ('deeplabv3plus', 2, 128, 128),
('deeplabv3', 1), ('deeplabv3', 1, 128, 64),
('deeplabv3', 2), ('deeplabv3', 2, 128, 64),
('deeplabv3plus_sum_to_merge', 1, 64, 128),
('deeplabv3plus_sum_to_merge', 2, 64, 128),
) )
def test_segmentation_head_creation(self, feature_fusion, upsample_factor): def test_segmentation_head_creation(self, feature_fusion, upsample_factor,
low_level_num_filters, expected_shape):
input_size = 128 input_size = 128
decoder_outupt_size = input_size // 2 decoder_outupt_size = input_size // 2
...@@ -42,14 +45,11 @@ class NNLayersTest(parameterized.TestCase, tf.test.TestCase): ...@@ -42,14 +45,11 @@ class NNLayersTest(parameterized.TestCase, tf.test.TestCase):
level=4, level=4,
upsample_factor=upsample_factor, upsample_factor=upsample_factor,
low_level=2, low_level=2,
low_level_num_filters=128, low_level_num_filters=low_level_num_filters,
feature_fusion=feature_fusion) feature_fusion=feature_fusion)
features = segmentation_head((backbone_output, decoder_output)) features = segmentation_head((backbone_output, decoder_output))
expected_shape = (
input_size
if feature_fusion == 'deeplabv3plus' else decoder_outupt_size)
self.assertAllEqual([ self.assertAllEqual([
2, expected_shape * upsample_factor, expected_shape * upsample_factor, 5 2, expected_shape * upsample_factor, expected_shape * upsample_factor, 5
], features.shape.as_list()) ], features.shape.as_list())
......
...@@ -233,8 +233,9 @@ class SegmentationHead(tf.keras.layers.Layer): ...@@ -233,8 +233,9 @@ class SegmentationHead(tf.keras.layers.Layer):
prediction layer. prediction layer.
upsample_factor: An `int` number to specify the upsampling factor to upsample_factor: An `int` number to specify the upsampling factor to
generate finer mask. Default 1 means no upsampling is applied. generate finer mask. Default 1 means no upsampling is applied.
feature_fusion: One of `deeplabv3plus`, `pyramid_fusion`, feature_fusion: One of the constants in nn_layers.FeatureFusion, namely
`panoptic_fpn_fusion`, or None. If `deeplabv3plus`, features from `deeplabv3plus`, `pyramid_fusion`, `panoptic_fpn_fusion`,
`deeplabv3plus_sum_to_merge`, or None. If `deeplabv3plus`, features from
decoder_features[level] will be fused with low level feature maps from decoder_features[level] will be fused with low level feature maps from
backbone. If `pyramid_fusion`, multiscale features will be resized and backbone. If `pyramid_fusion`, multiscale features will be resized and
fused at the target level. fused at the target level.
...@@ -245,10 +246,12 @@ class SegmentationHead(tf.keras.layers.Layer): ...@@ -245,10 +246,12 @@ class SegmentationHead(tf.keras.layers.Layer):
feature fusion. It is only used when feature_fusion is set to feature fusion. It is only used when feature_fusion is set to
`panoptic_fpn_fusion`. `panoptic_fpn_fusion`.
low_level: An `int` of backbone level to be used for feature fusion. It is low_level: An `int` of backbone level to be used for feature fusion. It is
used when feature_fusion is set to `deeplabv3plus`. used when feature_fusion is set to `deeplabv3plus` or
`deeplabv3plus_sum_to_merge`.
low_level_num_filters: An `int` of reduced number of filters for the low low_level_num_filters: An `int` of reduced number of filters for the low
level features before fusing it with higher level features. It is only level features before fusing it with higher level features. It is only
used when feature_fusion is set to `deeplabv3plus`. used when feature_fusion is set to `deeplabv3plus` or
`deeplabv3plus_sum_to_merge`.
num_decoder_filters: An `int` of number of filters in the decoder outputs. num_decoder_filters: An `int` of number of filters in the decoder outputs.
It is only used when feature_fusion is set to `panoptic_fpn_fusion`. It is only used when feature_fusion is set to `panoptic_fpn_fusion`.
activation: A `str` that indicates which activation is used, e.g. 'relu', activation: A `str` that indicates which activation is used, e.g. 'relu',
...@@ -312,7 +315,8 @@ class SegmentationHead(tf.keras.layers.Layer): ...@@ -312,7 +315,8 @@ class SegmentationHead(tf.keras.layers.Layer):
'epsilon': self._config_dict['norm_epsilon'], 'epsilon': self._config_dict['norm_epsilon'],
} }
if self._config_dict['feature_fusion'] == 'deeplabv3plus': if self._config_dict['feature_fusion'] in {'deeplabv3plus',
'deeplabv3plus_sum_to_merge'}:
# Deeplabv3+ feature fusion layers. # Deeplabv3+ feature fusion layers.
self._dlv3p_conv = conv_op( self._dlv3p_conv = conv_op(
kernel_size=1, kernel_size=1,
...@@ -398,7 +402,8 @@ class SegmentationHead(tf.keras.layers.Layer): ...@@ -398,7 +402,8 @@ class SegmentationHead(tf.keras.layers.Layer):
backbone_output = inputs[0] backbone_output = inputs[0]
decoder_output = inputs[1] decoder_output = inputs[1]
if self._config_dict['feature_fusion'] == 'deeplabv3plus': if self._config_dict['feature_fusion'] in {'deeplabv3plus',
'deeplabv3plus_sum_to_merge'}:
# deeplabv3+ feature fusion # deeplabv3+ feature fusion
x = decoder_output[str(self._config_dict['level'])] if isinstance( x = decoder_output[str(self._config_dict['level'])] if isinstance(
decoder_output, dict) else decoder_output decoder_output, dict) else decoder_output
...@@ -410,7 +415,10 @@ class SegmentationHead(tf.keras.layers.Layer): ...@@ -410,7 +415,10 @@ class SegmentationHead(tf.keras.layers.Layer):
x = tf.image.resize( x = tf.image.resize(
x, tf.shape(y)[1:3], method=tf.image.ResizeMethod.BILINEAR) x, tf.shape(y)[1:3], method=tf.image.ResizeMethod.BILINEAR)
x = tf.cast(x, dtype=y.dtype) x = tf.cast(x, dtype=y.dtype)
x = tf.concat([x, y], axis=self._bn_axis) if self._config_dict['feature_fusion'] == 'deeplabv3plus':
x = tf.concat([x, y], axis=self._bn_axis)
else:
x = tf.keras.layers.Add()([x, y])
elif self._config_dict['feature_fusion'] == 'pyramid_fusion': elif self._config_dict['feature_fusion'] == 'pyramid_fusion':
if not isinstance(decoder_output, dict): if not isinstance(decoder_output, dict):
raise ValueError('Only support dictionary decoder_output.') raise ValueError('Only support dictionary decoder_output.')
......
...@@ -30,7 +30,9 @@ class SegmentationHeadTest(parameterized.TestCase, tf.test.TestCase): ...@@ -30,7 +30,9 @@ class SegmentationHeadTest(parameterized.TestCase, tf.test.TestCase):
(2, 'panoptic_fpn_fusion', 2, 5), (2, 'panoptic_fpn_fusion', 2, 5),
(2, 'panoptic_fpn_fusion', 2, 6), (2, 'panoptic_fpn_fusion', 2, 6),
(3, 'panoptic_fpn_fusion', 3, 5), (3, 'panoptic_fpn_fusion', 3, 5),
(3, 'panoptic_fpn_fusion', 3, 6)) (3, 'panoptic_fpn_fusion', 3, 6),
(3, 'deeplabv3plus', 3, 6),
(3, 'deeplabv3plus_sum_to_merge', 3, 6))
def test_forward(self, level, feature_fusion, def test_forward(self, level, feature_fusion,
decoder_min_level, decoder_max_level): decoder_min_level, decoder_max_level):
backbone_features = { backbone_features = {
...@@ -52,6 +54,8 @@ class SegmentationHeadTest(parameterized.TestCase, tf.test.TestCase): ...@@ -52,6 +54,8 @@ class SegmentationHeadTest(parameterized.TestCase, tf.test.TestCase):
head = segmentation_heads.SegmentationHead( head = segmentation_heads.SegmentationHead(
num_classes=10, num_classes=10,
level=level, level=level,
low_level=decoder_min_level,
low_level_num_filters=64,
feature_fusion=feature_fusion, feature_fusion=feature_fusion,
decoder_min_level=decoder_min_level, decoder_min_level=decoder_min_level,
decoder_max_level=decoder_max_level, decoder_max_level=decoder_max_level,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment