Commit ce2aea2f authored by Fan Yang's avatar Fan Yang Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 428078415
parent 94696ab1
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Default 8-bit QuantizeConfigs."""
from typing import Sequence, Callable, Tuple, Any, Dict
import tensorflow as tf
import tensorflow_model_optimization as tfmot
Quantizer = tfmot.quantization.keras.quantizers.Quantizer
Layer = tf.keras.layers.Layer
Activation = Callable[[tf.Tensor], tf.Tensor]
WeightAndQuantizer = Tuple[tf.Variable, Quantizer]
ActivationAndQuantizer = Tuple[Activation, Quantizer]
class Default8BitOutputQuantizeConfig(tfmot.quantization.keras.QuantizeConfig):
"""QuantizeConfig which only quantizes the output from a layer."""
def get_weights_and_quantizers(
self, layer: Layer) -> Sequence[WeightAndQuantizer]:
return []
def get_activations_and_quantizers(
self, layer: Layer) -> Sequence[ActivationAndQuantizer]:
return []
def set_quantize_weights(self,
layer: Layer,
quantize_weights: Sequence[tf.Tensor]):
pass
def set_quantize_activations(self,
layer: Layer,
quantize_activations: Sequence[Activation]):
pass
def get_output_quantizers(self, layer: Layer) -> Sequence[Quantizer]:
return [
tfmot.quantization.keras.quantizers.MovingAverageQuantizer(
num_bits=8, per_axis=False, symmetric=False, narrow_range=False)
]
def get_config(self) -> Dict[str, Any]:
return {}
class NoOpQuantizeConfig(tfmot.quantization.keras.QuantizeConfig):
"""QuantizeConfig which does not quantize any part of the layer."""
def get_weights_and_quantizers(
self, layer: Layer) -> Sequence[WeightAndQuantizer]:
return []
def get_activations_and_quantizers(
self, layer: Layer) -> Sequence[ActivationAndQuantizer]:
return []
def set_quantize_weights(
self,
layer: Layer,
quantize_weights: Sequence[tf.Tensor]):
pass
def set_quantize_activations(
self,
layer: Layer,
quantize_activations: Sequence[Activation]):
pass
def get_output_quantizers(self, layer: Layer) -> Sequence[Quantizer]:
return []
def get_config(self) -> Dict[str, Any]:
return {}
class Default8BitQuantizeConfig(tfmot.quantization.keras.QuantizeConfig):
"""QuantizeConfig for non recurrent Keras layers."""
def __init__(self,
weight_attrs: Sequence[str],
activation_attrs: Sequence[str],
quantize_output: bool):
"""Initializes a default 8bit quantize config."""
self.weight_attrs = weight_attrs
self.activation_attrs = activation_attrs
self.quantize_output = quantize_output
# TODO(pulkitb): For some layers such as Conv2D, per_axis should be True.
# Add mapping for which layers support per_axis.
self.weight_quantizer = tfmot.quantization.keras.quantizers.LastValueQuantizer(
num_bits=8, per_axis=False, symmetric=True, narrow_range=True)
self.activation_quantizer = tfmot.quantization.keras.quantizers.MovingAverageQuantizer(
num_bits=8, per_axis=False, symmetric=False, narrow_range=False)
def get_weights_and_quantizers(
self, layer: Layer) -> Sequence[WeightAndQuantizer]:
"""See base class."""
return [(getattr(layer, weight_attr), self.weight_quantizer)
for weight_attr in self.weight_attrs]
def get_activations_and_quantizers(
self, layer: Layer) -> Sequence[ActivationAndQuantizer]:
"""See base class."""
return [(getattr(layer, activation_attr), self.activation_quantizer)
for activation_attr in self.activation_attrs]
def set_quantize_weights(
self,
layer: Layer,
quantize_weights: Sequence[tf.Tensor]):
"""See base class."""
if len(self.weight_attrs) != len(quantize_weights):
raise ValueError(
'`set_quantize_weights` called on layer {} with {} '
'weight parameters, but layer expects {} values.'.format(
layer.name, len(quantize_weights), len(self.weight_attrs)))
for weight_attr, weight in zip(self.weight_attrs, quantize_weights):
current_weight = getattr(layer, weight_attr)
if current_weight.shape != weight.shape:
raise ValueError('Existing layer weight shape {} is incompatible with'
'provided weight shape {}'.format(
current_weight.shape, weight.shape))
setattr(layer, weight_attr, weight)
def set_quantize_activations(
self,
layer: Layer,
quantize_activations: Sequence[Activation]):
"""See base class."""
if len(self.activation_attrs) != len(quantize_activations):
raise ValueError(
'`set_quantize_activations` called on layer {} with {} '
'activation parameters, but layer expects {} values.'.format(
layer.name, len(quantize_activations),
len(self.activation_attrs)))
for activation_attr, activation in zip(
self.activation_attrs, quantize_activations):
setattr(layer, activation_attr, activation)
def get_output_quantizers(self, layer: Layer) -> Sequence[Quantizer]:
"""See base class."""
if self.quantize_output:
return [self.activation_quantizer]
return []
@classmethod
def from_config(cls, config: Dict[str, Any]) -> object:
"""Instantiates a `Default8BitQuantizeConfig` from its config.
Args:
config: Output of `get_config()`.
Returns:
A `Default8BitQuantizeConfig` instance.
"""
return cls(**config)
def get_config(self) -> Dict[str, Any]:
"""Get a config for this quantize config."""
# TODO(pulkitb): Add weight and activation quantizer to config.
# Currently it's created internally, but ideally the quantizers should be
# part of the constructor and passed in from the registry.
return {
'weight_attrs': self.weight_attrs,
'activation_attrs': self.activation_attrs,
'quantize_output': self.quantize_output
}
def __eq__(self, other):
if not isinstance(other, Default8BitQuantizeConfig):
return False
return (self.weight_attrs == other.weight_attrs and
self.activation_attrs == self.activation_attrs and
self.weight_quantizer == other.weight_quantizer and
self.activation_quantizer == other.activation_quantizer and
self.quantize_output == other.quantize_output)
def __ne__(self, other):
return not self.__eq__(other)
class Default8BitConvWeightsQuantizer(
tfmot.quantization.keras.quantizers.LastValueQuantizer):
"""Quantizer for handling weights in Conv2D/DepthwiseConv2D layers."""
def __init__(self):
"""Construct LastValueQuantizer with params specific for TFLite Convs."""
super(Default8BitConvWeightsQuantizer, self).__init__(
num_bits=8, per_axis=True, symmetric=True, narrow_range=True)
def build(self,
tensor_shape: tf.TensorShape,
name: str,
layer: Layer):
"""Build min/max quantization variables."""
min_weight = layer.add_weight(
name + '_min',
shape=(tensor_shape[-1],),
initializer=tf.keras.initializers.Constant(-6.0),
trainable=False)
max_weight = layer.add_weight(
name + '_max',
shape=(tensor_shape[-1],),
initializer=tf.keras.initializers.Constant(6.0),
trainable=False)
return {'min_var': min_weight, 'max_var': max_weight}
class NoQuantizer(tfmot.quantization.keras.quantizers.Quantizer):
"""Dummy quantizer for explicitly not quantize."""
def __call__(self, inputs, training, weights, **kwargs):
return tf.identity(inputs)
def get_config(self):
return {}
def build(self, tensor_shape, name, layer):
return {}
class Default8BitConvQuantizeConfig(Default8BitQuantizeConfig):
"""QuantizeConfig for Conv2D/DepthwiseConv2D layers."""
def __init__(self,
weight_attrs: Sequence[str],
activation_attrs: Sequence[str],
quantize_output: bool):
"""Initializes default 8bit quantization config for the conv layer."""
super().__init__(weight_attrs, activation_attrs, quantize_output)
self.weight_quantizer = Default8BitConvWeightsQuantizer()
class Default8BitActivationQuantizeConfig(
tfmot.quantization.keras.QuantizeConfig):
"""QuantizeConfig for keras.layers.Activation.
`keras.layers.Activation` needs a separate `QuantizeConfig` since the
decision to quantize depends on the specific activation type.
"""
def _assert_activation_layer(self, layer: Layer):
if not isinstance(layer, tf.keras.layers.Activation):
raise RuntimeError(
'Default8BitActivationQuantizeConfig can only be used with '
'`keras.layers.Activation`.')
def get_weights_and_quantizers(
self, layer: Layer) -> Sequence[WeightAndQuantizer]:
"""See base class."""
self._assert_activation_layer(layer)
return []
def get_activations_and_quantizers(
self, layer: Layer) -> Sequence[ActivationAndQuantizer]:
"""See base class."""
self._assert_activation_layer(layer)
return []
def set_quantize_weights(
self,
layer: Layer,
quantize_weights: Sequence[tf.Tensor]):
"""See base class."""
self._assert_activation_layer(layer)
def set_quantize_activations(
self,
layer: Layer,
quantize_activations: Sequence[Activation]):
"""See base class."""
self._assert_activation_layer(layer)
def get_output_quantizers(self, layer: Layer) -> Sequence[Quantizer]:
"""See base class."""
self._assert_activation_layer(layer)
if not hasattr(layer.activation, '__name__'):
raise ValueError('Activation {} not supported by '
'Default8BitActivationQuantizeConfig.'.format(
layer.activation))
# This code is copied from TFMOT repo, but added relu6 to support mobilenet.
if layer.activation.__name__ in ['relu', 'relu6', 'swish', 'hard_swish']:
# 'relu' should generally get fused into the previous layer.
return [tfmot.quantization.keras.quantizers.MovingAverageQuantizer(
num_bits=8, per_axis=False, symmetric=False, narrow_range=False)]
elif layer.activation.__name__ in [
'linear', 'softmax', 'sigmoid', 'hard_sigmoid'
]:
return []
raise ValueError('Activation {} not supported by '
'Default8BitActivationQuantizeConfig.'.format(
layer.activation))
def get_config(self) -> Dict[str, Any]:
"""Get a config for this quantizer config."""
return {}
def _types_dict():
return {
'Default8BitOutputQuantizeConfig':
Default8BitOutputQuantizeConfig,
'NoOpQuantizeConfig':
NoOpQuantizeConfig,
'Default8BitQuantizeConfig':
Default8BitQuantizeConfig,
'Default8BitConvWeightsQuantizer':
Default8BitConvWeightsQuantizer,
'Default8BitConvQuantizeConfig':
Default8BitConvQuantizeConfig,
'Default8BitActivationQuantizeConfig':
Default8BitActivationQuantizeConfig,
}
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for configs.py."""
# Import libraries
import numpy as np
import tensorflow as tf
import tensorflow_model_optimization as tfmot
from official.projects.qat.vision.quantization import configs
class _TestHelper(object):
def _convert_list(self, list_of_tuples):
"""Transforms a list of 2-tuples to a tuple of 2 lists.
`QuantizeConfig` methods return a list of 2-tuples in the form
[(weight1, quantizer1), (weight2, quantizer2)]. This function converts
it into a 2-tuple of lists. ([weight1, weight2]), (quantizer1, quantizer2).
Args:
list_of_tuples: List of 2-tuples.
Returns:
2-tuple of lists.
"""
list1 = []
list2 = []
for a, b in list_of_tuples:
list1.append(a)
list2.append(b)
return list1, list2
# TODO(pulkitb): Consider asserting on full equality for quantizers.
def _assert_weight_quantizers(self, quantizer_list):
for quantizer in quantizer_list:
self.assertIsInstance(
quantizer,
tfmot.quantization.keras.quantizers.LastValueQuantizer)
def _assert_activation_quantizers(self, quantizer_list):
for quantizer in quantizer_list:
self.assertIsInstance(
quantizer,
tfmot.quantization.keras.quantizers.MovingAverageQuantizer)
def _assert_kernel_equality(self, a, b):
self.assertAllEqual(a.numpy(), b.numpy())
class Default8BitQuantizeConfigTest(tf.test.TestCase, _TestHelper):
def _simple_dense_layer(self):
layer = tf.keras.layers.Dense(2)
layer.build(input_shape=(3,))
return layer
def testGetsQuantizeWeightsAndQuantizers(self):
layer = self._simple_dense_layer()
quantize_config = configs.Default8BitQuantizeConfig(
['kernel'], ['activation'], False)
(weights, weight_quantizers) = self._convert_list(
quantize_config.get_weights_and_quantizers(layer))
self._assert_weight_quantizers(weight_quantizers)
self.assertEqual([layer.kernel], weights)
def testGetsQuantizeActivationsAndQuantizers(self):
layer = self._simple_dense_layer()
quantize_config = configs.Default8BitQuantizeConfig(
['kernel'], ['activation'], False)
(activations, activation_quantizers) = self._convert_list(
quantize_config.get_activations_and_quantizers(layer))
self._assert_activation_quantizers(activation_quantizers)
self.assertEqual([layer.activation], activations)
def testSetsQuantizeWeights(self):
layer = self._simple_dense_layer()
quantize_kernel = tf.keras.backend.variable(
np.ones(layer.kernel.shape.as_list()))
quantize_config = configs.Default8BitQuantizeConfig(
['kernel'], ['activation'], False)
quantize_config.set_quantize_weights(layer, [quantize_kernel])
self._assert_kernel_equality(layer.kernel, quantize_kernel)
def testSetsQuantizeActivations(self):
layer = self._simple_dense_layer()
quantize_activation = tf.keras.activations.relu
quantize_config = configs.Default8BitQuantizeConfig(
['kernel'], ['activation'], False)
quantize_config.set_quantize_activations(layer, [quantize_activation])
self.assertEqual(layer.activation, quantize_activation)
def testSetsQuantizeWeights_ErrorOnWrongNumberOfWeights(self):
layer = self._simple_dense_layer()
quantize_kernel = tf.keras.backend.variable(
np.ones(layer.kernel.shape.as_list()))
quantize_config = configs.Default8BitQuantizeConfig(
['kernel'], ['activation'], False)
with self.assertRaises(ValueError):
quantize_config.set_quantize_weights(layer, [])
with self.assertRaises(ValueError):
quantize_config.set_quantize_weights(layer,
[quantize_kernel, quantize_kernel])
def testSetsQuantizeWeights_ErrorOnWrongShapeOfWeight(self):
layer = self._simple_dense_layer()
quantize_kernel = tf.keras.backend.variable(np.ones([1, 2]))
quantize_config = configs.Default8BitQuantizeConfig(
['kernel'], ['activation'], False)
with self.assertRaises(ValueError):
quantize_config.set_quantize_weights(layer, [quantize_kernel])
def testSetsQuantizeActivations_ErrorOnWrongNumberOfActivations(self):
layer = self._simple_dense_layer()
quantize_activation = tf.keras.activations.relu
quantize_config = configs.Default8BitQuantizeConfig(
['kernel'], ['activation'], False)
with self.assertRaises(ValueError):
quantize_config.set_quantize_activations(layer, [])
with self.assertRaises(ValueError):
quantize_config.set_quantize_activations(
layer, [quantize_activation, quantize_activation])
def testGetsResultQuantizers_ReturnsQuantizer(self):
layer = self._simple_dense_layer()
quantize_config = configs.Default8BitQuantizeConfig(
[], [], True)
output_quantizers = quantize_config.get_output_quantizers(layer)
self.assertLen(output_quantizers, 1)
self._assert_activation_quantizers(output_quantizers)
def testGetsResultQuantizers_EmptyWhenFalse(self):
layer = self._simple_dense_layer()
quantize_config = configs.Default8BitQuantizeConfig(
[], [], False)
output_quantizers = quantize_config.get_output_quantizers(layer)
self.assertEqual([], output_quantizers)
def testSerialization(self):
quantize_config = configs.Default8BitQuantizeConfig(
['kernel'], ['activation'], False)
expected_config = {
'class_name': 'Default8BitQuantizeConfig',
'config': {
'weight_attrs': ['kernel'],
'activation_attrs': ['activation'],
'quantize_output': False
}
}
serialized_quantize_config = tf.keras.utils.serialize_keras_object(
quantize_config)
self.assertEqual(expected_config, serialized_quantize_config)
quantize_config_from_config = tf.keras.utils.deserialize_keras_object(
serialized_quantize_config,
module_objects=globals(),
custom_objects=configs._types_dict())
self.assertEqual(quantize_config, quantize_config_from_config)
if __name__ == '__main__':
tf.test.main()
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Quantization schemes."""
from typing import Type
# Import libraries
import tensorflow as tf
import tensorflow_model_optimization as tfmot
from official.projects.qat.vision.modeling.layers import nn_blocks as quantized_nn_blocks
from official.projects.qat.vision.modeling.layers import nn_layers as quantized_nn_layers
from official.projects.qat.vision.quantization import configs
keras = tf.keras
default_8bit_transforms = tfmot.quantization.keras.default_8bit.default_8bit_transforms
LayerNode = tfmot.quantization.keras.graph_transformations.transforms.LayerNode
LayerPattern = tfmot.quantization.keras.graph_transformations.transforms.LayerPattern
_QUANTIZATION_WEIGHT_NAMES = [
'output_max', 'output_min', 'optimizer_step',
'kernel_min', 'kernel_max',
'depthwise_kernel_min', 'depthwise_kernel_max',
'reduce_mean_quantizer_vars_min', 'reduce_mean_quantizer_vars_max']
_ORIGINAL_WEIGHT_NAME = [
'kernel', 'depthwise_kernel',
'gamma', 'beta', 'moving_mean', 'moving_variance',
'bias']
class CustomLayerQuantize(
tfmot.quantization.keras.graph_transformations.transforms.Transform):
"""Add QAT support for Keras Custom layer."""
def __init__(self,
original_layer_pattern: str,
quantized_layer_class: Type[keras.layers.Layer]):
super(CustomLayerQuantize, self).__init__()
self._original_layer_pattern = original_layer_pattern
self._quantized_layer_class = quantized_layer_class
def pattern(self) -> LayerPattern:
"""See base class."""
return LayerPattern(self._original_layer_pattern)
def _is_quantization_weight_name(self, name):
simple_name = name.split('/')[-1].split(':')[0]
if simple_name in _QUANTIZATION_WEIGHT_NAMES:
return True
if simple_name in _ORIGINAL_WEIGHT_NAME:
return False
raise ValueError('Variable name {} is not supported on '
'CustomLayerQuantize({}) transform.'.format(
simple_name,
self._original_layer_pattern))
def replacement(self, match_layer: LayerNode) -> LayerNode:
"""See base class."""
bottleneck_layer = match_layer.layer
bottleneck_config = bottleneck_layer['config']
bottleneck_names_and_weights = list(match_layer.names_and_weights)
quantized_layer = self._quantized_layer_class(
**bottleneck_config)
dummy_input_shape = [1, 64, 128, 1]
# SegmentationHead layer requires a tuple of 2 tensors.
if isinstance(quantized_layer,
quantized_nn_layers.SegmentationHeadQuantized):
dummy_input_shape = ([1, 1, 1, 1], [1, 1, 1, 1])
quantized_layer.compute_output_shape(dummy_input_shape)
quantized_names_and_weights = zip(
[weight.name for weight in quantized_layer.weights],
quantized_layer.get_weights())
match_idx = 0
names_and_weights = []
for name_and_weight in quantized_names_and_weights:
if not self._is_quantization_weight_name(name=name_and_weight[0]):
name_and_weight = bottleneck_names_and_weights[match_idx]
match_idx = match_idx + 1
names_and_weights.append(name_and_weight)
if match_idx != len(bottleneck_names_and_weights):
raise ValueError('{}/{} of Bottleneck weights is transformed.'.format(
match_idx, len(bottleneck_names_and_weights)))
quantized_layer_config = keras.layers.serialize(quantized_layer)
quantized_layer_config['name'] = quantized_layer_config['config']['name']
if bottleneck_layer['class_name'] in [
'Vision>Conv2DBNBlock', 'Vision>InvertedBottleneckBlock',
'Vision>SegmentationHead', 'Vision>SpatialPyramidPooling',
'Vision>ASPP',
# TODO(yeqing): Removes the Beta layers.
'Beta>Conv2DBNBlock', 'Beta>InvertedBottleneckBlock',
'Beta>SegmentationHead', 'Beta>SpatialPyramidPooling', 'Beta>ASPP'
]:
layer_metadata = {'quantize_config': configs.NoOpQuantizeConfig()}
else:
layer_metadata = {
'quantize_config': configs.Default8BitOutputQuantizeConfig()
}
return LayerNode(
quantized_layer_config,
metadata=layer_metadata,
names_and_weights=names_and_weights)
class QuantizeLayoutTransform(
tfmot.quantization.keras.QuantizeLayoutTransform):
"""Default model transformations."""
def apply(self, model, layer_quantize_map):
"""Implement default 8-bit transforms.
Currently this means the following.
1. Pull activations into layers, and apply fuse activations. (TODO)
2. Modify range in incoming layers for Concat. (TODO)
3. Fuse Conv2D/DepthwiseConv2D + BN into single layer.
Args:
model: Keras model to be quantized.
layer_quantize_map: Map with keys as layer names, and values as dicts
containing custom `QuantizeConfig`s which may have been passed with
layers.
Returns:
(Transformed Keras model to better match TensorFlow Lite backend, updated
layer quantize map.)
"""
transforms = [
default_8bit_transforms.InputLayerQuantize(),
default_8bit_transforms.SeparableConv1DQuantize(),
default_8bit_transforms.SeparableConvQuantize(),
default_8bit_transforms.Conv2DReshapeBatchNormReLUQuantize(),
default_8bit_transforms.Conv2DReshapeBatchNormActivationQuantize(),
default_8bit_transforms.Conv2DBatchNormReLUQuantize(),
default_8bit_transforms.Conv2DBatchNormActivationQuantize(),
default_8bit_transforms.Conv2DReshapeBatchNormQuantize(),
default_8bit_transforms.Conv2DBatchNormQuantize(),
default_8bit_transforms.ConcatTransform6Inputs(),
default_8bit_transforms.ConcatTransform5Inputs(),
default_8bit_transforms.ConcatTransform4Inputs(),
default_8bit_transforms.ConcatTransform3Inputs(),
default_8bit_transforms.ConcatTransform(),
default_8bit_transforms.LayerReLUQuantize(),
default_8bit_transforms.LayerReluActivationQuantize(),
CustomLayerQuantize('Vision>BottleneckBlock',
quantized_nn_blocks.BottleneckBlockQuantized),
CustomLayerQuantize(
'Vision>InvertedBottleneckBlock',
quantized_nn_blocks.InvertedBottleneckBlockQuantized),
CustomLayerQuantize('Vision>Conv2DBNBlock',
quantized_nn_blocks.Conv2DBNBlockQuantized),
CustomLayerQuantize('Vision>SegmentationHead',
quantized_nn_layers.SegmentationHeadQuantized),
CustomLayerQuantize('Vision>SpatialPyramidPooling',
quantized_nn_layers.SpatialPyramidPoolingQuantized),
CustomLayerQuantize('Vision>ASPP', quantized_nn_layers.ASPPQuantized),
# TODO(yeqing): Remove the `Beta` components.
CustomLayerQuantize('Beta>BottleneckBlock',
quantized_nn_blocks.BottleneckBlockQuantized),
CustomLayerQuantize(
'Beta>InvertedBottleneckBlock',
quantized_nn_blocks.InvertedBottleneckBlockQuantized),
CustomLayerQuantize('Beta>Conv2DBNBlock',
quantized_nn_blocks.Conv2DBNBlockQuantized),
CustomLayerQuantize('Beta>SegmentationHead',
quantized_nn_layers.SegmentationHeadQuantized),
CustomLayerQuantize('Beta>SpatialPyramidPooling',
quantized_nn_layers.SpatialPyramidPoolingQuantized),
CustomLayerQuantize('Beta>ASPP', quantized_nn_layers.ASPPQuantized)
]
return tfmot.quantization.keras.graph_transformations.model_transformer.ModelTransformer(
model, transforms,
set(layer_quantize_map.keys()), layer_quantize_map).transform()
class Default8BitQuantizeScheme(
tfmot.quantization.keras.default_8bit.Default8BitQuantizeScheme):
def get_layout_transformer(self):
return QuantizeLayoutTransform()
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""All necessary imports for registration on qat project."""
# pylint: disable=unused-import
from official.projects.qat.vision import configs
from official.projects.qat.vision.modeling import layers
from official.projects.qat.vision.tasks import image_classification
from official.projects.qat.vision.tasks import retinanet
from official.projects.qat.vision.tasks import semantic_segmentation
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Tasks package definition."""
from official.projects.qat.vision.tasks import image_classification
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Image classification task definition."""
import tensorflow as tf
from official.core import task_factory
from official.projects.qat.vision.configs import image_classification as exp_cfg
from official.projects.qat.vision.modeling import factory
from official.vision.beta.tasks import image_classification
@task_factory.register_task_cls(exp_cfg.ImageClassificationTask)
class ImageClassificationTask(image_classification.ImageClassificationTask):
"""A task for image classification with QAT."""
def build_model(self) -> tf.keras.Model:
"""Builds classification model with QAT."""
input_specs = tf.keras.layers.InputSpec(
shape=[None] + self.task_config.model.input_size)
l2_weight_decay = self.task_config.losses.l2_weight_decay
# Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
# (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
# (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
l2_regularizer = (tf.keras.regularizers.l2(
l2_weight_decay / 2.0) if l2_weight_decay else None)
model = super(ImageClassificationTask, self).build_model()
if self.task_config.quantization:
model = factory.build_qat_classification_model(
model,
self.task_config.quantization,
input_specs=input_specs,
model_config=self.task_config.model,
l2_regularizer=l2_regularizer)
return model
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for image classification task."""
# pylint: disable=unused-import
from absl.testing import parameterized
import orbit
import tensorflow as tf
from official.core import exp_factory
from official.modeling import optimization
from official.projects.qat.vision.tasks import image_classification as img_cls_task
from official.vision import beta
class ImageClassificationTaskTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.parameters(('resnet_imagenet_qat'),
('mobilenet_imagenet_qat'))
def test_task(self, config_name):
config = exp_factory.get_exp_config(config_name)
config.task.train_data.global_batch_size = 2
task = img_cls_task.ImageClassificationTask(config.task)
model = task.build_model()
metrics = task.build_metrics()
strategy = tf.distribute.get_strategy()
dataset = orbit.utils.make_distributed_dataset(strategy, task.build_inputs,
config.task.train_data)
iterator = iter(dataset)
opt_factory = optimization.OptimizerFactory(config.trainer.optimizer_config)
optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())
logs = task.train_step(next(iterator), model, optimizer, metrics=metrics)
for metric in metrics:
logs[metric.name] = metric.result()
self.assertIn('loss', logs)
self.assertIn('accuracy', logs)
self.assertIn('top_5_accuracy', logs)
logs = task.validation_step(next(iterator), model, metrics=metrics)
for metric in metrics:
logs[metric.name] = metric.result()
self.assertIn('loss', logs)
self.assertIn('accuracy', logs)
self.assertIn('top_5_accuracy', logs)
if __name__ == '__main__':
tf.test.main()
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""RetinaNet task definition."""
import tensorflow as tf
from official.core import task_factory
from official.projects.qat.vision.configs import retinanet as exp_cfg
from official.projects.qat.vision.modeling import factory
from official.vision.beta.tasks import retinanet
@task_factory.register_task_cls(exp_cfg.RetinaNetTask)
class RetinaNetTask(retinanet.RetinaNetTask):
"""A task for RetinaNet object detection with QAT."""
def build_model(self) -> tf.keras.Model:
"""Builds RetinaNet model with QAT."""
model = super(RetinaNetTask, self).build_model()
if self.task_config.quantization:
model = factory.build_qat_retinanet(
model,
self.task_config.quantization,
model_config=self.task_config.model)
return model
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Tests for RetinaNet task."""
# pylint: disable=unused-import
from absl.testing import parameterized
import orbit
import tensorflow as tf
from official.core import exp_factory
from official.modeling import optimization
from official.projects.qat.vision.tasks import retinanet
from official.vision import beta
from official.vision.beta.configs import retinanet as exp_cfg
class RetinaNetTaskTest(parameterized.TestCase, tf.test.TestCase):
@parameterized.parameters(
('retinanet_spinenet_mobile_coco_qat', True),
('retinanet_spinenet_mobile_coco_qat', False),
)
def test_retinanet_task(self, test_config, is_training):
"""RetinaNet task test for training and val using toy configs."""
config = exp_factory.get_exp_config(test_config)
# modify config to suit local testing
config.task.model.input_size = [128, 128, 3]
config.trainer.steps_per_loop = 1
config.task.train_data.global_batch_size = 1
config.task.validation_data.global_batch_size = 1
config.task.train_data.shuffle_buffer_size = 2
config.task.validation_data.shuffle_buffer_size = 2
config.train_steps = 1
task = retinanet.RetinaNetTask(config.task)
model = task.build_model()
metrics = task.build_metrics(training=is_training)
strategy = tf.distribute.get_strategy()
data_config = config.task.train_data if is_training else config.task.validation_data
dataset = orbit.utils.make_distributed_dataset(strategy, task.build_inputs,
data_config)
iterator = iter(dataset)
opt_factory = optimization.OptimizerFactory(config.trainer.optimizer_config)
optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())
if is_training:
task.train_step(next(iterator), model, optimizer, metrics=metrics)
else:
task.validation_step(next(iterator), model, metrics=metrics)
if __name__ == '__main__':
tf.test.main()
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Semantic segmentation task definition."""
import tensorflow as tf
from official.core import task_factory
from official.projects.qat.vision.configs import semantic_segmentation as exp_cfg
from official.projects.qat.vision.modeling import factory
from official.vision.beta.tasks import semantic_segmentation
@task_factory.register_task_cls(exp_cfg.SemanticSegmentationTask)
class SemanticSegmentationTask(semantic_segmentation.SemanticSegmentationTask):
"""A task for semantic segmentation with QAT."""
def build_model(self) -> tf.keras.Model:
"""Builds semantic segmentation model with QAT."""
model = super().build_model()
input_specs = tf.keras.layers.InputSpec(shape=[None] +
self.task_config.model.input_size)
if self.task_config.quantization:
model = factory.build_qat_segmentation_model(
model, self.task_config.quantization, input_specs)
return model
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for semantic segmentation task."""
# pylint: disable=unused-import
from absl.testing import parameterized
import orbit
import tensorflow as tf
from official.core import exp_factory
from official.modeling import optimization
from official.projects.qat.vision.tasks import semantic_segmentation
from official.vision import beta
from official.vision.beta.configs import semantic_segmentation as exp_cfg
class SemanticSegmentationTaskTest(parameterized.TestCase, tf.test.TestCase):
@parameterized.parameters(
('mnv2_deeplabv3_pascal_qat', True),
('mnv2_deeplabv3_pascal_qat', False),
)
def test_semantic_segmentation_task(self, test_config, is_training):
"""Semantic segmentation task test for training and val using toy configs."""
config = exp_factory.get_exp_config(test_config)
# modify config to suit local testing
config.task.model.input_size = [512, 512, 3]
config.trainer.steps_per_loop = 1
config.task.train_data.global_batch_size = 1
config.task.validation_data.global_batch_size = 1
config.task.train_data.shuffle_buffer_size = 2
config.task.validation_data.shuffle_buffer_size = 2
config.train_steps = 1
config.task.model.decoder.aspp.output_tensor = True
task = semantic_segmentation.SemanticSegmentationTask(config.task)
model = task.build_model()
metrics = task.build_metrics(training=is_training)
strategy = tf.distribute.get_strategy()
data_config = config.task.train_data if is_training else config.task.validation_data
dataset = orbit.utils.make_distributed_dataset(strategy, task.build_inputs,
data_config)
iterator = iter(dataset)
opt_factory = optimization.OptimizerFactory(config.trainer.optimizer_config)
optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())
if is_training:
task.train_step(next(iterator), model, optimizer, metrics=metrics)
else:
task.validation_step(next(iterator), model, metrics=metrics)
if __name__ == '__main__':
tf.test.main()
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TensorFlow Model Garden Vision training driver, including QAT configs.."""
from absl import app
from official.common import flags as tfm_flags
from official.projects.qat.vision import registry_imports # pylint: disable=unused-import
from official.vision.beta import train
if __name__ == '__main__':
tfm_flags.define_flags()
app.run(train.main)
......@@ -148,7 +148,7 @@ def semantic_segmentation() -> cfg.ExperimentConfig:
# PASCAL VOC 2012 Dataset
PASCAL_TRAIN_EXAMPLES = 10582
PASCAL_VAL_EXAMPLES = 1449
PASCAL_INPUT_PATH_BASE = 'pascal_voc_seg'
PASCAL_INPUT_PATH_BASE = 'gs://**/pascal_voc_seg'
@exp_factory.register_config_factory('seg_deeplabv3_pascal')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment