Commit c1c9bb0f authored by Fan Yang, committed by A. Unique TensorFlower

Remove duplicate functions and classes in QAT project

PiperOrigin-RevId: 437133811
parent d8f05122
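
The duplicated NoOpActivation class and _quantize_wrapped_layer factory that each QAT layer module carried are deleted below, and call sites switch to the single copies kept in quantization/helper.py. A rough sketch of the call-site change, using names taken verbatim from the hunks that follow (old form shown as a comment):

import tensorflow as tf

from official.projects.qat.vision.quantization import configs
from official.projects.qat.vision.quantization import helper

# Old: every module defined its own private wrapper factory.
# conv2d_quantized = _quantize_wrapped_layer(
#     tf.keras.layers.Conv2D,
#     configs.Default8BitConvQuantizeConfig(['kernel'], ['activation'], False))

# New: the shared helper builds the same QuantizeWrapperV2-wrapped constructor.
conv2d_quantized = helper.quantize_wrapped_layer(
    tf.keras.layers.Conv2D,
    configs.Default8BitConvQuantizeConfig(['kernel'], ['activation'], False))
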
@@ -24,42 +24,10 @@ import tensorflow_model_optimization as tfmot
from official.modeling import tf_utils
from official.projects.qat.vision.modeling.layers import nn_layers as qat_nn_layers
from official.projects.qat.vision.quantization import configs
from official.projects.qat.vision.quantization import helper
from official.vision.modeling.layers import nn_layers
class NoOpActivation:
"""No-op activation which simply returns the incoming tensor.
This activation is required to distinguish between `keras.activations.linear`
which does the same thing. The main difference is that NoOpActivation should
not have any quantize operation applied to it.
"""
def __call__(self, x: tf.Tensor) -> tf.Tensor:
return x
def get_config(self) -> Dict[str, Any]:
"""Get a config of this object."""
return {}
def __eq__(self, other: Any) -> bool:
if not other or not isinstance(other, NoOpActivation):
return False
return True
def __ne__(self, other: Any) -> bool:
return not self.__eq__(other)
def _quantize_wrapped_layer(cls, quantize_config):
def constructor(*arg, **kwargs):
return tfmot.quantization.keras.QuantizeWrapperV2(
cls(*arg, **kwargs),
quantize_config)
return constructor
# This class is copied from modeling.layers.nn_blocks.BottleneckBlock and apply
# QAT.
@tf.keras.utils.register_keras_serializable(package='Vision')
@@ -131,17 +99,16 @@ class BottleneckBlockQuantized(tf.keras.layers.Layer):
self._kernel_regularizer = kernel_regularizer
self._bias_regularizer = bias_regularizer
if use_sync_bn:
self._norm = _quantize_wrapped_layer(
self._norm = helper.quantize_wrapped_layer(
tf.keras.layers.experimental.SyncBatchNormalization,
configs.NoOpQuantizeConfig())
self._norm_with_quantize = _quantize_wrapped_layer(
self._norm_with_quantize = helper.quantize_wrapped_layer(
tf.keras.layers.experimental.SyncBatchNormalization,
configs.Default8BitOutputQuantizeConfig())
else:
self._norm = _quantize_wrapped_layer(
tf.keras.layers.BatchNormalization,
configs.NoOpQuantizeConfig())
self._norm_with_quantize = _quantize_wrapped_layer(
self._norm = helper.quantize_wrapped_layer(
tf.keras.layers.BatchNormalization, configs.NoOpQuantizeConfig())
self._norm_with_quantize = helper.quantize_wrapped_layer(
tf.keras.layers.BatchNormalization,
configs.Default8BitOutputQuantizeConfig())
if tf.keras.backend.image_data_format() == 'channels_last':
@@ -152,10 +119,10 @@ class BottleneckBlockQuantized(tf.keras.layers.Layer):
def build(self, input_shape: Optional[Union[Sequence[int], tf.Tensor]]):
"""Build variables and child layers to prepare for calling."""
conv2d_quantized = _quantize_wrapped_layer(
conv2d_quantized = helper.quantize_wrapped_layer(
tf.keras.layers.Conv2D,
configs.Default8BitConvQuantizeConfig(
['kernel'], ['activation'], False))
configs.Default8BitConvQuantizeConfig(['kernel'], ['activation'],
False))
if self._use_projection:
if self._resnetd_shortcut:
self._shortcut0 = tf.keras.layers.AveragePooling2D(
@@ -168,7 +135,7 @@ class BottleneckBlockQuantized(tf.keras.layers.Layer):
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activation=NoOpActivation())
activation=helper.NoOpActivation())
else:
self._shortcut = conv2d_quantized(
filters=self._filters * 4,
@@ -178,7 +145,7 @@ class BottleneckBlockQuantized(tf.keras.layers.Layer):
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activation=NoOpActivation())
activation=helper.NoOpActivation())
self._norm0 = self._norm_with_quantize(
axis=self._bn_axis,
@@ -194,7 +161,7 @@ class BottleneckBlockQuantized(tf.keras.layers.Layer):
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activation=NoOpActivation())
activation=helper.NoOpActivation())
self._norm1 = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
@@ -214,7 +181,7 @@ class BottleneckBlockQuantized(tf.keras.layers.Layer):
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activation=NoOpActivation())
activation=helper.NoOpActivation())
self._norm2 = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
@@ -232,7 +199,7 @@ class BottleneckBlockQuantized(tf.keras.layers.Layer):
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activation=NoOpActivation())
activation=helper.NoOpActivation())
self._norm3 = self._norm_with_quantize(
axis=self._bn_axis,
momentum=self._norm_momentum,
@@ -392,10 +359,10 @@ class Conv2DBNBlockQuantized(tf.keras.layers.Layer):
norm_layer = (
tf.keras.layers.experimental.SyncBatchNormalization
if use_sync_bn else tf.keras.layers.BatchNormalization)
self._norm_with_quantize = _quantize_wrapped_layer(
self._norm_with_quantize = helper.quantize_wrapped_layer(
norm_layer, configs.Default8BitOutputQuantizeConfig())
self._norm = _quantize_wrapped_layer(norm_layer,
configs.NoOpQuantizeConfig())
self._norm = helper.quantize_wrapped_layer(norm_layer,
configs.NoOpQuantizeConfig())
if tf.keras.backend.image_data_format() == 'channels_last':
self._bn_axis = -1
@@ -432,10 +399,10 @@ class Conv2DBNBlockQuantized(tf.keras.layers.Layer):
if self._use_explicit_padding and self._kernel_size > 1:
padding_size = nn_layers.get_padding_for_kernel_size(self._kernel_size)
self._pad = tf.keras.layers.ZeroPadding2D(padding_size)
conv2d_quantized = _quantize_wrapped_layer(
conv2d_quantized = helper.quantize_wrapped_layer(
tf.keras.layers.Conv2D,
configs.Default8BitConvQuantizeConfig(
['kernel'], ['activation'], not self._use_normalization))
configs.Default8BitConvQuantizeConfig(['kernel'], ['activation'],
not self._use_normalization))
self._conv0 = conv2d_quantized(
filters=self._filters,
kernel_size=self._kernel_size,
@@ -445,7 +412,7 @@ class Conv2DBNBlockQuantized(tf.keras.layers.Layer):
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activation=NoOpActivation())
activation=helper.NoOpActivation())
if self._use_normalization:
self._norm0 = self._norm_by_activation(self._activation)(
axis=self._bn_axis,
@@ -579,10 +546,10 @@ class InvertedBottleneckBlockQuantized(tf.keras.layers.Layer):
norm_layer = (
tf.keras.layers.experimental.SyncBatchNormalization
if use_sync_bn else tf.keras.layers.BatchNormalization)
self._norm_with_quantize = _quantize_wrapped_layer(
self._norm_with_quantize = helper.quantize_wrapped_layer(
norm_layer, configs.Default8BitOutputQuantizeConfig())
self._norm = _quantize_wrapped_layer(norm_layer,
configs.NoOpQuantizeConfig())
self._norm = helper.quantize_wrapped_layer(norm_layer,
configs.NoOpQuantizeConfig())
if tf.keras.backend.image_data_format() == 'channels_last':
self._bn_axis = -1
@@ -602,14 +569,14 @@ class InvertedBottleneckBlockQuantized(tf.keras.layers.Layer):
def build(self, input_shape: Optional[Union[Sequence[int], tf.Tensor]]):
"""Build variables and child layers to prepare for calling."""
conv2d_quantized = _quantize_wrapped_layer(
conv2d_quantized = helper.quantize_wrapped_layer(
tf.keras.layers.Conv2D,
configs.Default8BitConvQuantizeConfig(
['kernel'], ['activation'], False))
depthwise_conv2d_quantized = _quantize_wrapped_layer(
configs.Default8BitConvQuantizeConfig(['kernel'], ['activation'],
False))
depthwise_conv2d_quantized = helper.quantize_wrapped_layer(
tf.keras.layers.DepthwiseConv2D,
configs.Default8BitConvQuantizeConfig(
['depthwise_kernel'], ['activation'], False))
configs.Default8BitConvQuantizeConfig(['depthwise_kernel'],
['activation'], False))
expand_filters = self._in_filters
if self._expand_ratio > 1:
# First 1x1 conv for channel expansion.
@@ -628,7 +595,7 @@ class InvertedBottleneckBlockQuantized(tf.keras.layers.Layer):
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activation=NoOpActivation())
activation=helper.NoOpActivation())
self._norm0 = self._norm_by_activation(self._activation)(
axis=self._bn_axis,
momentum=self._norm_momentum,
@@ -649,7 +616,7 @@ class InvertedBottleneckBlockQuantized(tf.keras.layers.Layer):
depthwise_initializer=self._kernel_initializer,
depthwise_regularizer=self._depthsize_regularizer,
bias_regularizer=self._bias_regularizer,
activation=NoOpActivation())
activation=helper.NoOpActivation())
self._norm1 = self._norm_by_activation(self._depthwise_activation)(
axis=self._bn_axis,
momentum=self._norm_momentum,
@@ -690,7 +657,7 @@ class InvertedBottleneckBlockQuantized(tf.keras.layers.Layer):
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activation=NoOpActivation())
activation=helper.NoOpActivation())
self._norm2 = self._norm_with_quantize(
axis=self._bn_axis,
momentum=self._norm_momentum,
@@ -14,7 +14,7 @@
"""Contains common building blocks for neural networks."""
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union
from typing import Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union
import tensorflow as tf
@@ -31,36 +31,6 @@ States = Dict[str, tf.Tensor]
Activation = Union[str, Callable]
class NoOpActivation:
"""No-op activation which simply returns the incoming tensor.
This activation is required to distinguish between `keras.activations.linear`
which does the same thing. The main difference is that NoOpActivation should
not have any quantize operation applied to it.
"""
def __call__(self, x: tf.Tensor) -> tf.Tensor:
return x
def get_config(self) -> Dict[str, Any]:
"""Get a config of this object."""
return {}
def __eq__(self, other: Any) -> bool:
return isinstance(other, NoOpActivation)
def __ne__(self, other: Any) -> bool:
return not self.__eq__(other)
def _quantize_wrapped_layer(cls, quantize_config):
def constructor(*arg, **kwargs):
return tfmot.quantization.keras.QuantizeWrapperV2(
cls(*arg, **kwargs),
quantize_config)
return constructor
@tf.keras.utils.register_keras_serializable(package='Vision')
class SqueezeExcitationQuantized(
helper.LayerQuantizerHelper,
@@ -154,14 +124,13 @@ class SqueezeExcitationQuantized(
return x
def build(self, input_shape):
conv2d_quantized = _quantize_wrapped_layer(
conv2d_quantized = helper.quantize_wrapped_layer(
tf.keras.layers.Conv2D,
configs.Default8BitConvQuantizeConfig(
['kernel'], ['activation'], False))
conv2d_quantized_output_quantized = _quantize_wrapped_layer(
configs.Default8BitConvQuantizeConfig(['kernel'], ['activation'],
False))
conv2d_quantized_output_quantized = helper.quantize_wrapped_layer(
tf.keras.layers.Conv2D,
configs.Default8BitConvQuantizeConfig(
['kernel'], ['activation'], True))
configs.Default8BitConvQuantizeConfig(['kernel'], ['activation'], True))
num_reduced_filters = nn_layers.make_divisible(
max(1, int(self._in_filters * self._se_ratio)),
divisor=self._divisible_by,
@@ -176,7 +145,7 @@ class SqueezeExcitationQuantized(
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activation=NoOpActivation())
activation=helper.NoOpActivation())
self._se_expand = conv2d_quantized_output_quantized(
filters=self._out_filters,
@@ -187,7 +156,7 @@ class SqueezeExcitationQuantized(
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activation=NoOpActivation())
activation=helper.NoOpActivation())
self._multiply = tfmot.quantization.keras.QuantizeWrapperV2(
tf.keras.layers.Multiply(),
@@ -342,14 +311,14 @@ class SegmentationHeadQuantized(tf.keras.layers.Layer):
backbone_shape = input_shape[0]
use_depthwise_convolution = self._config_dict['use_depthwise_convolution']
random_initializer = tf.keras.initializers.RandomNormal(stddev=0.01)
conv2d_quantized = _quantize_wrapped_layer(
conv2d_quantized = helper.quantize_wrapped_layer(
tf.keras.layers.Conv2D,
configs.Default8BitConvQuantizeConfig(['kernel'], ['activation'],
False))
conv2d_quantized_output_quantized = _quantize_wrapped_layer(
conv2d_quantized_output_quantized = helper.quantize_wrapped_layer(
tf.keras.layers.Conv2D,
configs.Default8BitConvQuantizeConfig(['kernel'], ['activation'], True))
depthwise_conv2d_quantized = _quantize_wrapped_layer(
depthwise_conv2d_quantized = helper.quantize_wrapped_layer(
tf.keras.layers.DepthwiseConv2D,
configs.Default8BitConvQuantizeConfig(['depthwise_kernel'],
['activation'], False))
@@ -365,11 +334,13 @@ class SegmentationHeadQuantized(tf.keras.layers.Layer):
tf.keras.layers.experimental.SyncBatchNormalization
if self._config_dict['use_sync_bn'] else
tf.keras.layers.BatchNormalization)
norm_with_quantize = _quantize_wrapped_layer(
norm_with_quantize = helper.quantize_wrapped_layer(
norm_layer, configs.Default8BitOutputQuantizeConfig())
norm = norm_with_quantize if self._config_dict['activation'] not in [
'relu', 'relu6'
] else _quantize_wrapped_layer(norm_layer, configs.NoOpQuantizeConfig())
if self._config_dict['activation'] not in ['relu', 'relu6']:
norm = norm_with_quantize
else:
norm = helper.quantize_wrapped_layer(norm_layer,
configs.NoOpQuantizeConfig())
bn_kwargs = {
'axis': self._bn_axis,
@@ -387,7 +358,7 @@ class SegmentationHeadQuantized(tf.keras.layers.Layer):
kernel_regularizer=self._config_dict['kernel_regularizer'],
name='segmentation_head_deeplabv3p_fusion_conv',
filters=self._config_dict['low_level_num_filters'],
activation=NoOpActivation())
activation=helper.NoOpActivation())
self._dlv3p_norm = norm(
name='segmentation_head_deeplabv3p_fusion_norm', **bn_kwargs)
@@ -406,7 +377,7 @@ class SegmentationHeadQuantized(tf.keras.layers.Layer):
depthwise_initializer=random_initializer,
depthwise_regularizer=self._config_dict['kernel_regularizer'],
depth_multiplier=1,
activation=NoOpActivation()))
activation=helper.NoOpActivation()))
norm_name = 'segmentation_head_depthwise_norm_{}'.format(i)
self._norms.append(norm(name=norm_name, **bn_kwargs))
conv_name = 'segmentation_head_conv_{}'.format(i)
@@ -414,7 +385,7 @@ class SegmentationHeadQuantized(tf.keras.layers.Layer):
conv2d_quantized(
name=conv_name,
filters=self._config_dict['num_filters'],
activation=NoOpActivation(),
activation=helper.NoOpActivation(),
**conv_kwargs))
norm_name = 'segmentation_head_norm_{}'.format(i)
self._norms.append(norm(name=norm_name, **bn_kwargs))
@@ -428,9 +399,9 @@ class SegmentationHeadQuantized(tf.keras.layers.Layer):
kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.01),
kernel_regularizer=self._config_dict['kernel_regularizer'],
bias_regularizer=self._config_dict['bias_regularizer'],
activation=NoOpActivation())
activation=helper.NoOpActivation())
upsampling = _quantize_wrapped_layer(
upsampling = helper.quantize_wrapped_layer(
tf.keras.layers.UpSampling2D,
configs.Default8BitQuantizeConfig([], [], True))
self._upsampling_layer = upsampling(
@@ -440,7 +411,7 @@ class SegmentationHeadQuantized(tf.keras.layers.Layer):
self._resizing_layer = tf.keras.layers.Resizing(
backbone_shape[1], backbone_shape[2], interpolation='bilinear')
concat = _quantize_wrapped_layer(
concat = helper.quantize_wrapped_layer(
tf.keras.layers.Concatenate,
configs.Default8BitQuantizeConfig([], [], True))
self._concat_layer = concat(axis=self._bn_axis)
@@ -589,17 +560,19 @@ class SpatialPyramidPoolingQuantized(nn_layers.SpatialPyramidPooling):
norm_layer = (
tf.keras.layers.experimental.SyncBatchNormalization
if self._use_sync_bn else tf.keras.layers.BatchNormalization)
norm_with_quantize = _quantize_wrapped_layer(
norm_with_quantize = helper.quantize_wrapped_layer(
norm_layer, configs.Default8BitOutputQuantizeConfig())
norm = norm_with_quantize if self._activation not in [
'relu', 'relu6'
] else _quantize_wrapped_layer(norm_layer, configs.NoOpQuantizeConfig())
if self._activation not in ['relu', 'relu6']:
norm = norm_with_quantize
else:
norm = helper.quantize_wrapped_layer(norm_layer,
configs.NoOpQuantizeConfig())
conv2d_quantized = _quantize_wrapped_layer(
conv2d_quantized = helper.quantize_wrapped_layer(
tf.keras.layers.Conv2D,
configs.Default8BitConvQuantizeConfig(['kernel'], ['activation'],
False))
depthwise_conv2d_quantized_output_quantized = _quantize_wrapped_layer(
depthwise_conv2d_quantized_output_quantized = helper.quantize_wrapped_layer(
tf.keras.layers.DepthwiseConv2D,
configs.Default8BitConvQuantizeConfig(['depthwise_kernel'],
['activation'], True))
@@ -612,7 +585,7 @@ class SpatialPyramidPoolingQuantized(nn_layers.SpatialPyramidPooling):
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
use_bias=False,
activation=NoOpActivation())
activation=helper.NoOpActivation())
norm1 = norm(
axis=self._bn_axis,
momentum=self._batchnorm_momentum,
@@ -633,7 +606,7 @@ class SpatialPyramidPoolingQuantized(nn_layers.SpatialPyramidPooling):
depthwise_initializer=self._kernel_initializer,
dilation_rate=dilation_rate,
use_bias=False,
activation=NoOpActivation())
activation=helper.NoOpActivation())
]
kernel_size = (1, 1)
conv_dilation = leading_layers + [
@@ -645,7 +618,7 @@ class SpatialPyramidPoolingQuantized(nn_layers.SpatialPyramidPooling):
kernel_initializer=self._kernel_initializer,
dilation_rate=dilation_rate,
use_bias=False,
activation=NoOpActivation())
activation=helper.NoOpActivation())
]
norm_dilation = norm(
axis=self._bn_axis,
@@ -656,16 +629,16 @@ class SpatialPyramidPoolingQuantized(nn_layers.SpatialPyramidPooling):
if self._pool_kernel_size is None:
pooling = [
_quantize_wrapped_layer(
helper.quantize_wrapped_layer(
tf.keras.layers.GlobalAveragePooling2D,
configs.Default8BitQuantizeConfig([], [], True))(),
_quantize_wrapped_layer(
helper.quantize_wrapped_layer(
tf.keras.layers.Reshape,
configs.Default8BitQuantizeConfig([], [], True))((1, 1, channels))
]
else:
pooling = [
_quantize_wrapped_layer(
helper.quantize_wrapped_layer(
tf.keras.layers.AveragePooling2D,
configs.Default8BitQuantizeConfig([], [],
True))(self._pool_kernel_size)
@@ -677,7 +650,7 @@ class SpatialPyramidPoolingQuantized(nn_layers.SpatialPyramidPooling):
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
use_bias=False,
activation=NoOpActivation())
activation=helper.NoOpActivation())
norm2 = norm(
axis=self._bn_axis,
momentum=self._batchnorm_momentum,
@@ -685,7 +658,7 @@ class SpatialPyramidPoolingQuantized(nn_layers.SpatialPyramidPooling):
self.aspp_layers.append(pooling + [conv2, norm2])
resizing = _quantize_wrapped_layer(
resizing = helper.quantize_wrapped_layer(
tf.keras.layers.Resizing, configs.Default8BitQuantizeConfig([], [],
True))
self._resizing_layer = resizing(
@@ -698,14 +671,14 @@ class SpatialPyramidPoolingQuantized(nn_layers.SpatialPyramidPooling):
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
use_bias=False,
activation=NoOpActivation()),
activation=helper.NoOpActivation()),
norm_with_quantize(
axis=self._bn_axis,
momentum=self._batchnorm_momentum,
epsilon=self._batchnorm_epsilon)
]
self._dropout_layer = tf.keras.layers.Dropout(rate=self._dropout)
concat = _quantize_wrapped_layer(
concat = helper.quantize_wrapped_layer(
tf.keras.layers.Concatenate,
configs.Default8BitQuantizeConfig([], [], True))
self._concat_layer = concat(axis=-1)
@@ -13,7 +13,9 @@
# limitations under the License.
"""Quantization helpers."""
from typing import Any, Dict
import tensorflow as tf
import tensorflow_model_optimization as tfmot
@@ -47,3 +49,37 @@ class LayerQuantizerHelper(object):
for name in self._quantizers:
self._quantizer_vars[name] = self._quantizers[name].build(
tensor_shape=None, name=name, layer=self)
class NoOpActivation:
"""No-op activation which simply returns the incoming tensor.
This activation is required to distinguish between `keras.activations.linear`
which does the same thing. The main difference is that NoOpActivation should
not have any quantize operation applied to it.
"""
def __call__(self, x: tf.Tensor) -> tf.Tensor:
return x
def get_config(self) -> Dict[str, Any]:
"""Get a config of this object."""
return {}
def __eq__(self, other: Any) -> bool:
if not other or not isinstance(other, NoOpActivation):
return False
return True
def __ne__(self, other: Any) -> bool:
return not self.__eq__(other)
def quantize_wrapped_layer(cls, quantize_config):
def constructor(*arg, **kwargs):
return tfmot.quantization.keras.QuantizeWrapperV2(
cls(*arg, **kwargs), quantize_config)
return constructor
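
For reference, a minimal usage sketch of the consolidated helpers (illustrative only; it assumes tensorflow and tensorflow_model_optimization are installed alongside the model garden code):

import tensorflow as tf
import tensorflow_model_optimization as tfmot

from official.projects.qat.vision.quantization import configs
from official.projects.qat.vision.quantization import helper

# quantize_wrapped_layer returns a constructor whose instances come back
# wrapped in QuantizeWrapperV2 with the supplied quantize config.
concat = helper.quantize_wrapped_layer(
    tf.keras.layers.Concatenate,
    configs.Default8BitQuantizeConfig([], [], True))(axis=-1)
assert isinstance(concat, tfmot.quantization.keras.QuantizeWrapperV2)

# NoOpActivation passes tensors through untouched and compares equal by type,
# so layers that use it do not get an extra activation fake-quant op.
act = helper.NoOpActivation()
x = tf.constant([1.0, 2.0])
assert act(x) is x
assert act == helper.NoOpActivation()
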