Commit ccbac18b authored by Fan Yang, committed by A. Unique TensorFlower

Remove duplicate functions and classes in QAT project

PiperOrigin-RevId: 437133811
parent ceadccbe
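
This commit replaces the per-module copies of `NoOpActivation` and `_quantize_wrapped_layer` in the QAT layer modules with a single implementation in `official.projects.qat.vision.quantization.helper`. A minimal usage sketch of the shared helpers after this change (assuming the TF Model Garden packages are importable; the layer parameters below are illustrative, not taken from the diff):

import tensorflow as tf

from official.projects.qat.vision.quantization import configs
from official.projects.qat.vision.quantization import helper

# Bind Conv2D to a default 8-bit quantize config once; the returned constructor
# produces Conv2D layers wrapped in tfmot QuantizeWrapperV2.
conv2d_quantized = helper.quantize_wrapped_layer(
    tf.keras.layers.Conv2D,
    configs.Default8BitConvQuantizeConfig(['kernel'], ['activation'], False))

# helper.NoOpActivation marks a convolution whose activation should not be
# quantized, unlike keras.activations.linear which would get a quantize op.
conv = conv2d_quantized(
    filters=64,     # illustrative value
    kernel_size=3,  # illustrative value
    padding='same',
    use_bias=False,
    activation=helper.NoOpActivation())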
@@ -24,42 +24,10 @@ import tensorflow_model_optimization as tfmot
 from official.modeling import tf_utils
 from official.projects.qat.vision.modeling.layers import nn_layers as qat_nn_layers
 from official.projects.qat.vision.quantization import configs
+from official.projects.qat.vision.quantization import helper
 from official.vision.modeling.layers import nn_layers
 
-
-class NoOpActivation:
-  """No-op activation which simply returns the incoming tensor.
-
-  This activation is required to distinguish between `keras.activations.linear`
-  which does the same thing. The main difference is that NoOpActivation should
-  not have any quantize operation applied to it.
-  """
-
-  def __call__(self, x: tf.Tensor) -> tf.Tensor:
-    return x
-
-  def get_config(self) -> Dict[str, Any]:
-    """Get a config of this object."""
-    return {}
-
-  def __eq__(self, other: Any) -> bool:
-    if not other or not isinstance(other, NoOpActivation):
-      return False
-    return True
-
-  def __ne__(self, other: Any) -> bool:
-    return not self.__eq__(other)
-
-
-def _quantize_wrapped_layer(cls, quantize_config):
-
-  def constructor(*arg, **kwargs):
-    return tfmot.quantization.keras.QuantizeWrapperV2(
-        cls(*arg, **kwargs),
-        quantize_config)
-
-  return constructor
-
-
 # This class is copied from modeling.layers.nn_blocks.BottleneckBlock and apply
 # QAT.
 @tf.keras.utils.register_keras_serializable(package='Vision')
@@ -131,17 +99,16 @@ class BottleneckBlockQuantized(tf.keras.layers.Layer):
     self._kernel_regularizer = kernel_regularizer
     self._bias_regularizer = bias_regularizer
     if use_sync_bn:
-      self._norm = _quantize_wrapped_layer(
+      self._norm = helper.quantize_wrapped_layer(
          tf.keras.layers.experimental.SyncBatchNormalization,
          configs.NoOpQuantizeConfig())
-      self._norm_with_quantize = _quantize_wrapped_layer(
+      self._norm_with_quantize = helper.quantize_wrapped_layer(
          tf.keras.layers.experimental.SyncBatchNormalization,
          configs.Default8BitOutputQuantizeConfig())
     else:
-      self._norm = _quantize_wrapped_layer(
-          tf.keras.layers.BatchNormalization,
-          configs.NoOpQuantizeConfig())
-      self._norm_with_quantize = _quantize_wrapped_layer(
+      self._norm = helper.quantize_wrapped_layer(
+          tf.keras.layers.BatchNormalization, configs.NoOpQuantizeConfig())
+      self._norm_with_quantize = helper.quantize_wrapped_layer(
          tf.keras.layers.BatchNormalization,
          configs.Default8BitOutputQuantizeConfig())
     if tf.keras.backend.image_data_format() == 'channels_last':
@@ -152,10 +119,10 @@ class BottleneckBlockQuantized(tf.keras.layers.Layer):
   def build(self, input_shape: Optional[Union[Sequence[int], tf.Tensor]]):
     """Build variables and child layers to prepare for calling."""
-    conv2d_quantized = _quantize_wrapped_layer(
+    conv2d_quantized = helper.quantize_wrapped_layer(
        tf.keras.layers.Conv2D,
-        configs.Default8BitConvQuantizeConfig(
-            ['kernel'], ['activation'], False))
+        configs.Default8BitConvQuantizeConfig(['kernel'], ['activation'],
+                                              False))
     if self._use_projection:
       if self._resnetd_shortcut:
         self._shortcut0 = tf.keras.layers.AveragePooling2D(
@@ -168,7 +135,7 @@ class BottleneckBlockQuantized(tf.keras.layers.Layer):
             kernel_initializer=self._kernel_initializer,
             kernel_regularizer=self._kernel_regularizer,
             bias_regularizer=self._bias_regularizer,
-            activation=NoOpActivation())
+            activation=helper.NoOpActivation())
       else:
         self._shortcut = conv2d_quantized(
             filters=self._filters * 4,
@@ -178,7 +145,7 @@ class BottleneckBlockQuantized(tf.keras.layers.Layer):
             kernel_initializer=self._kernel_initializer,
             kernel_regularizer=self._kernel_regularizer,
             bias_regularizer=self._bias_regularizer,
-            activation=NoOpActivation())
+            activation=helper.NoOpActivation())
       self._norm0 = self._norm_with_quantize(
          axis=self._bn_axis,
@@ -194,7 +161,7 @@ class BottleneckBlockQuantized(tf.keras.layers.Layer):
        kernel_initializer=self._kernel_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer,
-        activation=NoOpActivation())
+        activation=helper.NoOpActivation())
     self._norm1 = self._norm(
        axis=self._bn_axis,
        momentum=self._norm_momentum,
@@ -214,7 +181,7 @@ class BottleneckBlockQuantized(tf.keras.layers.Layer):
        kernel_initializer=self._kernel_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer,
-        activation=NoOpActivation())
+        activation=helper.NoOpActivation())
     self._norm2 = self._norm(
        axis=self._bn_axis,
        momentum=self._norm_momentum,
@@ -232,7 +199,7 @@ class BottleneckBlockQuantized(tf.keras.layers.Layer):
        kernel_initializer=self._kernel_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer,
-        activation=NoOpActivation())
+        activation=helper.NoOpActivation())
     self._norm3 = self._norm_with_quantize(
        axis=self._bn_axis,
        momentum=self._norm_momentum,
@@ -392,9 +359,9 @@ class Conv2DBNBlockQuantized(tf.keras.layers.Layer):
     norm_layer = (
        tf.keras.layers.experimental.SyncBatchNormalization
        if use_sync_bn else tf.keras.layers.BatchNormalization)
-    self._norm_with_quantize = _quantize_wrapped_layer(
+    self._norm_with_quantize = helper.quantize_wrapped_layer(
        norm_layer, configs.Default8BitOutputQuantizeConfig())
-    self._norm = _quantize_wrapped_layer(norm_layer,
-                                         configs.NoOpQuantizeConfig())
+    self._norm = helper.quantize_wrapped_layer(norm_layer,
+                                               configs.NoOpQuantizeConfig())
     if tf.keras.backend.image_data_format() == 'channels_last':
@@ -432,10 +399,10 @@ class Conv2DBNBlockQuantized(tf.keras.layers.Layer):
     if self._use_explicit_padding and self._kernel_size > 1:
       padding_size = nn_layers.get_padding_for_kernel_size(self._kernel_size)
       self._pad = tf.keras.layers.ZeroPadding2D(padding_size)
-    conv2d_quantized = _quantize_wrapped_layer(
+    conv2d_quantized = helper.quantize_wrapped_layer(
        tf.keras.layers.Conv2D,
-        configs.Default8BitConvQuantizeConfig(
-            ['kernel'], ['activation'], not self._use_normalization))
+        configs.Default8BitConvQuantizeConfig(['kernel'], ['activation'],
+                                              not self._use_normalization))
     self._conv0 = conv2d_quantized(
        filters=self._filters,
        kernel_size=self._kernel_size,
@@ -445,7 +412,7 @@ class Conv2DBNBlockQuantized(tf.keras.layers.Layer):
        kernel_initializer=self._kernel_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer,
-        activation=NoOpActivation())
+        activation=helper.NoOpActivation())
     if self._use_normalization:
       self._norm0 = self._norm_by_activation(self._activation)(
          axis=self._bn_axis,
@@ -579,9 +546,9 @@ class InvertedBottleneckBlockQuantized(tf.keras.layers.Layer):
     norm_layer = (
        tf.keras.layers.experimental.SyncBatchNormalization
        if use_sync_bn else tf.keras.layers.BatchNormalization)
-    self._norm_with_quantize = _quantize_wrapped_layer(
+    self._norm_with_quantize = helper.quantize_wrapped_layer(
        norm_layer, configs.Default8BitOutputQuantizeConfig())
-    self._norm = _quantize_wrapped_layer(norm_layer,
-                                         configs.NoOpQuantizeConfig())
+    self._norm = helper.quantize_wrapped_layer(norm_layer,
+                                               configs.NoOpQuantizeConfig())
     if tf.keras.backend.image_data_format() == 'channels_last':
@@ -602,14 +569,14 @@ class InvertedBottleneckBlockQuantized(tf.keras.layers.Layer):
   def build(self, input_shape: Optional[Union[Sequence[int], tf.Tensor]]):
     """Build variables and child layers to prepare for calling."""
-    conv2d_quantized = _quantize_wrapped_layer(
+    conv2d_quantized = helper.quantize_wrapped_layer(
        tf.keras.layers.Conv2D,
-        configs.Default8BitConvQuantizeConfig(
-            ['kernel'], ['activation'], False))
-    depthwise_conv2d_quantized = _quantize_wrapped_layer(
+        configs.Default8BitConvQuantizeConfig(['kernel'], ['activation'],
+                                              False))
+    depthwise_conv2d_quantized = helper.quantize_wrapped_layer(
        tf.keras.layers.DepthwiseConv2D,
-        configs.Default8BitConvQuantizeConfig(
-            ['depthwise_kernel'], ['activation'], False))
+        configs.Default8BitConvQuantizeConfig(['depthwise_kernel'],
+                                              ['activation'], False))
     expand_filters = self._in_filters
     if self._expand_ratio > 1:
       # First 1x1 conv for channel expansion.
@@ -628,7 +595,7 @@ class InvertedBottleneckBlockQuantized(tf.keras.layers.Layer):
        kernel_initializer=self._kernel_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer,
-        activation=NoOpActivation())
+        activation=helper.NoOpActivation())
     self._norm0 = self._norm_by_activation(self._activation)(
        axis=self._bn_axis,
        momentum=self._norm_momentum,
@@ -649,7 +616,7 @@ class InvertedBottleneckBlockQuantized(tf.keras.layers.Layer):
        depthwise_initializer=self._kernel_initializer,
        depthwise_regularizer=self._depthsize_regularizer,
        bias_regularizer=self._bias_regularizer,
-        activation=NoOpActivation())
+        activation=helper.NoOpActivation())
     self._norm1 = self._norm_by_activation(self._depthwise_activation)(
        axis=self._bn_axis,
        momentum=self._norm_momentum,
@@ -690,7 +657,7 @@ class InvertedBottleneckBlockQuantized(tf.keras.layers.Layer):
        kernel_initializer=self._kernel_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer,
-        activation=NoOpActivation())
+        activation=helper.NoOpActivation())
     self._norm2 = self._norm_with_quantize(
        axis=self._bn_axis,
        momentum=self._norm_momentum,
...
@@ -14,7 +14,7 @@
 """Contains common building blocks for neural networks."""
-from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union
+from typing import Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union
 
 import tensorflow as tf
@@ -31,36 +31,6 @@ States = Dict[str, tf.Tensor]
 Activation = Union[str, Callable]
 
-
-class NoOpActivation:
-  """No-op activation which simply returns the incoming tensor.
-
-  This activation is required to distinguish between `keras.activations.linear`
-  which does the same thing. The main difference is that NoOpActivation should
-  not have any quantize operation applied to it.
-  """
-
-  def __call__(self, x: tf.Tensor) -> tf.Tensor:
-    return x
-
-  def get_config(self) -> Dict[str, Any]:
-    """Get a config of this object."""
-    return {}
-
-  def __eq__(self, other: Any) -> bool:
-    return isinstance(other, NoOpActivation)
-
-  def __ne__(self, other: Any) -> bool:
-    return not self.__eq__(other)
-
-
-def _quantize_wrapped_layer(cls, quantize_config):
-
-  def constructor(*arg, **kwargs):
-    return tfmot.quantization.keras.QuantizeWrapperV2(
-        cls(*arg, **kwargs),
-        quantize_config)
-
-  return constructor
-
-
 @tf.keras.utils.register_keras_serializable(package='Vision')
 class SqueezeExcitationQuantized(
     helper.LayerQuantizerHelper,
@@ -154,14 +124,13 @@ class SqueezeExcitationQuantized(
     return x
 
   def build(self, input_shape):
-    conv2d_quantized = _quantize_wrapped_layer(
+    conv2d_quantized = helper.quantize_wrapped_layer(
        tf.keras.layers.Conv2D,
-        configs.Default8BitConvQuantizeConfig(
-            ['kernel'], ['activation'], False))
-    conv2d_quantized_output_quantized = _quantize_wrapped_layer(
+        configs.Default8BitConvQuantizeConfig(['kernel'], ['activation'],
+                                              False))
+    conv2d_quantized_output_quantized = helper.quantize_wrapped_layer(
        tf.keras.layers.Conv2D,
-        configs.Default8BitConvQuantizeConfig(
-            ['kernel'], ['activation'], True))
+        configs.Default8BitConvQuantizeConfig(['kernel'], ['activation'], True))
     num_reduced_filters = nn_layers.make_divisible(
        max(1, int(self._in_filters * self._se_ratio)),
        divisor=self._divisible_by,
@@ -176,7 +145,7 @@ class SqueezeExcitationQuantized(
        kernel_initializer=self._kernel_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer,
-        activation=NoOpActivation())
+        activation=helper.NoOpActivation())
     self._se_expand = conv2d_quantized_output_quantized(
        filters=self._out_filters,
@@ -187,7 +156,7 @@ class SqueezeExcitationQuantized(
        kernel_initializer=self._kernel_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer,
-        activation=NoOpActivation())
+        activation=helper.NoOpActivation())
     self._multiply = tfmot.quantization.keras.QuantizeWrapperV2(
        tf.keras.layers.Multiply(),
@@ -342,14 +311,14 @@ class SegmentationHeadQuantized(tf.keras.layers.Layer):
     backbone_shape = input_shape[0]
     use_depthwise_convolution = self._config_dict['use_depthwise_convolution']
     random_initializer = tf.keras.initializers.RandomNormal(stddev=0.01)
-    conv2d_quantized = _quantize_wrapped_layer(
+    conv2d_quantized = helper.quantize_wrapped_layer(
        tf.keras.layers.Conv2D,
        configs.Default8BitConvQuantizeConfig(['kernel'], ['activation'],
                                              False))
-    conv2d_quantized_output_quantized = _quantize_wrapped_layer(
+    conv2d_quantized_output_quantized = helper.quantize_wrapped_layer(
        tf.keras.layers.Conv2D,
        configs.Default8BitConvQuantizeConfig(['kernel'], ['activation'], True))
-    depthwise_conv2d_quantized = _quantize_wrapped_layer(
+    depthwise_conv2d_quantized = helper.quantize_wrapped_layer(
        tf.keras.layers.DepthwiseConv2D,
        configs.Default8BitConvQuantizeConfig(['depthwise_kernel'],
                                              ['activation'], False))
@@ -365,11 +334,13 @@ class SegmentationHeadQuantized(tf.keras.layers.Layer):
        tf.keras.layers.experimental.SyncBatchNormalization
        if self._config_dict['use_sync_bn'] else
        tf.keras.layers.BatchNormalization)
-    norm_with_quantize = _quantize_wrapped_layer(
+    norm_with_quantize = helper.quantize_wrapped_layer(
        norm_layer, configs.Default8BitOutputQuantizeConfig())
-    norm = norm_with_quantize if self._config_dict['activation'] not in [
-        'relu', 'relu6'
-    ] else _quantize_wrapped_layer(norm_layer, configs.NoOpQuantizeConfig())
+    if self._config_dict['activation'] not in ['relu', 'relu6']:
+      norm = norm_with_quantize
+    else:
+      norm = helper.quantize_wrapped_layer(norm_layer,
+                                           configs.NoOpQuantizeConfig())
 
     bn_kwargs = {
        'axis': self._bn_axis,
@@ -387,7 +358,7 @@ class SegmentationHeadQuantized(tf.keras.layers.Layer):
          kernel_regularizer=self._config_dict['kernel_regularizer'],
          name='segmentation_head_deeplabv3p_fusion_conv',
          filters=self._config_dict['low_level_num_filters'],
-          activation=NoOpActivation())
+          activation=helper.NoOpActivation())
       self._dlv3p_norm = norm(
          name='segmentation_head_deeplabv3p_fusion_norm', **bn_kwargs)
@@ -406,7 +377,7 @@ class SegmentationHeadQuantized(tf.keras.layers.Layer):
                depthwise_initializer=random_initializer,
                depthwise_regularizer=self._config_dict['kernel_regularizer'],
                depth_multiplier=1,
-                activation=NoOpActivation()))
+                activation=helper.NoOpActivation()))
       norm_name = 'segmentation_head_depthwise_norm_{}'.format(i)
       self._norms.append(norm(name=norm_name, **bn_kwargs))
       conv_name = 'segmentation_head_conv_{}'.format(i)
@@ -414,7 +385,7 @@ class SegmentationHeadQuantized(tf.keras.layers.Layer):
          conv2d_quantized(
              name=conv_name,
              filters=self._config_dict['num_filters'],
-              activation=NoOpActivation(),
+              activation=helper.NoOpActivation(),
              **conv_kwargs))
       norm_name = 'segmentation_head_norm_{}'.format(i)
       self._norms.append(norm(name=norm_name, **bn_kwargs))
@@ -428,9 +399,9 @@ class SegmentationHeadQuantized(tf.keras.layers.Layer):
        kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.01),
        kernel_regularizer=self._config_dict['kernel_regularizer'],
        bias_regularizer=self._config_dict['bias_regularizer'],
-        activation=NoOpActivation())
+        activation=helper.NoOpActivation())
 
-    upsampling = _quantize_wrapped_layer(
+    upsampling = helper.quantize_wrapped_layer(
        tf.keras.layers.UpSampling2D,
        configs.Default8BitQuantizeConfig([], [], True))
     self._upsampling_layer = upsampling(
@@ -440,7 +411,7 @@ class SegmentationHeadQuantized(tf.keras.layers.Layer):
     self._resizing_layer = tf.keras.layers.Resizing(
        backbone_shape[1], backbone_shape[2], interpolation='bilinear')
 
-    concat = _quantize_wrapped_layer(
+    concat = helper.quantize_wrapped_layer(
        tf.keras.layers.Concatenate,
        configs.Default8BitQuantizeConfig([], [], True))
     self._concat_layer = concat(axis=self._bn_axis)
@@ -589,17 +560,19 @@ class SpatialPyramidPoolingQuantized(nn_layers.SpatialPyramidPooling):
     norm_layer = (
        tf.keras.layers.experimental.SyncBatchNormalization
        if self._use_sync_bn else tf.keras.layers.BatchNormalization)
-    norm_with_quantize = _quantize_wrapped_layer(
+    norm_with_quantize = helper.quantize_wrapped_layer(
        norm_layer, configs.Default8BitOutputQuantizeConfig())
-    norm = norm_with_quantize if self._activation not in [
-        'relu', 'relu6'
-    ] else _quantize_wrapped_layer(norm_layer, configs.NoOpQuantizeConfig())
+    if self._activation not in ['relu', 'relu6']:
+      norm = norm_with_quantize
+    else:
+      norm = helper.quantize_wrapped_layer(norm_layer,
+                                           configs.NoOpQuantizeConfig())
 
-    conv2d_quantized = _quantize_wrapped_layer(
+    conv2d_quantized = helper.quantize_wrapped_layer(
        tf.keras.layers.Conv2D,
        configs.Default8BitConvQuantizeConfig(['kernel'], ['activation'],
                                              False))
-    depthwise_conv2d_quantized_output_quantized = _quantize_wrapped_layer(
+    depthwise_conv2d_quantized_output_quantized = helper.quantize_wrapped_layer(
        tf.keras.layers.DepthwiseConv2D,
        configs.Default8BitConvQuantizeConfig(['depthwise_kernel'],
                                              ['activation'], True))
@@ -612,7 +585,7 @@ class SpatialPyramidPoolingQuantized(nn_layers.SpatialPyramidPooling):
        kernel_initializer=self._kernel_initializer,
        kernel_regularizer=self._kernel_regularizer,
        use_bias=False,
-        activation=NoOpActivation())
+        activation=helper.NoOpActivation())
     norm1 = norm(
        axis=self._bn_axis,
        momentum=self._batchnorm_momentum,
@@ -633,7 +606,7 @@ class SpatialPyramidPoolingQuantized(nn_layers.SpatialPyramidPooling):
            depthwise_initializer=self._kernel_initializer,
            dilation_rate=dilation_rate,
            use_bias=False,
-            activation=NoOpActivation())
+            activation=helper.NoOpActivation())
       ]
       kernel_size = (1, 1)
     conv_dilation = leading_layers + [
@@ -645,7 +618,7 @@ class SpatialPyramidPoolingQuantized(nn_layers.SpatialPyramidPooling):
            kernel_initializer=self._kernel_initializer,
            dilation_rate=dilation_rate,
            use_bias=False,
-            activation=NoOpActivation())
+            activation=helper.NoOpActivation())
     ]
     norm_dilation = norm(
        axis=self._bn_axis,
@@ -656,16 +629,16 @@ class SpatialPyramidPoolingQuantized(nn_layers.SpatialPyramidPooling):
     if self._pool_kernel_size is None:
       pooling = [
-          _quantize_wrapped_layer(
+          helper.quantize_wrapped_layer(
              tf.keras.layers.GlobalAveragePooling2D,
              configs.Default8BitQuantizeConfig([], [], True))(),
-          _quantize_wrapped_layer(
+          helper.quantize_wrapped_layer(
              tf.keras.layers.Reshape,
              configs.Default8BitQuantizeConfig([], [], True))((1, 1, channels))
       ]
     else:
       pooling = [
-          _quantize_wrapped_layer(
+          helper.quantize_wrapped_layer(
              tf.keras.layers.AveragePooling2D,
              configs.Default8BitQuantizeConfig([], [],
                                                True))(self._pool_kernel_size)
@@ -677,7 +650,7 @@ class SpatialPyramidPoolingQuantized(nn_layers.SpatialPyramidPooling):
        kernel_initializer=self._kernel_initializer,
        kernel_regularizer=self._kernel_regularizer,
        use_bias=False,
-        activation=NoOpActivation())
+        activation=helper.NoOpActivation())
     norm2 = norm(
        axis=self._bn_axis,
        momentum=self._batchnorm_momentum,
@@ -685,7 +658,7 @@ class SpatialPyramidPoolingQuantized(nn_layers.SpatialPyramidPooling):
     self.aspp_layers.append(pooling + [conv2, norm2])
 
-    resizing = _quantize_wrapped_layer(
+    resizing = helper.quantize_wrapped_layer(
        tf.keras.layers.Resizing, configs.Default8BitQuantizeConfig([], [],
                                                                    True))
     self._resizing_layer = resizing(
@@ -698,14 +671,14 @@ class SpatialPyramidPoolingQuantized(nn_layers.SpatialPyramidPooling):
            kernel_initializer=self._kernel_initializer,
            kernel_regularizer=self._kernel_regularizer,
            use_bias=False,
-            activation=NoOpActivation()),
+            activation=helper.NoOpActivation()),
        norm_with_quantize(
            axis=self._bn_axis,
            momentum=self._batchnorm_momentum,
            epsilon=self._batchnorm_epsilon)
     ]
 
     self._dropout_layer = tf.keras.layers.Dropout(rate=self._dropout)
-    concat = _quantize_wrapped_layer(
+    concat = helper.quantize_wrapped_layer(
        tf.keras.layers.Concatenate,
        configs.Default8BitQuantizeConfig([], [], True))
     self._concat_layer = concat(axis=-1)
...
@@ -13,7 +13,9 @@
 # limitations under the License.
 
 """Quantization helpers."""
+from typing import Any, Dict
 
+import tensorflow as tf
 import tensorflow_model_optimization as tfmot
@@ -47,3 +49,37 @@ class LayerQuantizerHelper(object):
     for name in self._quantizers:
       self._quantizer_vars[name] = self._quantizers[name].build(
          tensor_shape=None, name=name, layer=self)
+
+
+class NoOpActivation:
+  """No-op activation which simply returns the incoming tensor.
+
+  This activation is required to distinguish between `keras.activations.linear`
+  which does the same thing. The main difference is that NoOpActivation should
+  not have any quantize operation applied to it.
+  """
+
+  def __call__(self, x: tf.Tensor) -> tf.Tensor:
+    return x
+
+  def get_config(self) -> Dict[str, Any]:
+    """Get a config of this object."""
+    return {}
+
+  def __eq__(self, other: Any) -> bool:
+    if not other or not isinstance(other, NoOpActivation):
+      return False
+    return True
+
+  def __ne__(self, other: Any) -> bool:
+    return not self.__eq__(other)
+
+
+def quantize_wrapped_layer(cls, quantize_config):
+
+  def constructor(*arg, **kwargs):
+    return tfmot.quantization.keras.QuantizeWrapperV2(
+        cls(*arg, **kwargs), quantize_config)
+
+  return constructor
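
For reference, a brief usage sketch of the helper consolidated above: `quantize_wrapped_layer` returns a constructor, and calling that constructor yields the layer wrapped in `tfmot.quantization.keras.QuantizeWrapperV2` (the BatchNormalization arguments below are illustrative):

import tensorflow as tf
import tensorflow_model_optimization as tfmot

from official.projects.qat.vision.quantization import configs
from official.projects.qat.vision.quantization import helper

# Step 1: bind the layer class to a quantize config.
norm_ctor = helper.quantize_wrapped_layer(
    tf.keras.layers.BatchNormalization, configs.NoOpQuantizeConfig())

# Step 2: construct instances with the layer's normal keyword arguments.
bn = norm_ctor(axis=-1, momentum=0.99, epsilon=1e-3)  # illustrative kwargs
assert isinstance(bn, tfmot.quantization.keras.QuantizeWrapperV2)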