Commit 1051697d authored by Jaehong Kim, committed by A. Unique TensorFlower

Add decoder part quantization flag for the object detection model.

PiperOrigin-RevId: 448858109
parent 2aad04b0
@@ -30,6 +30,8 @@ class Quantization(hyperparams.Config):
     change_num_bits: A `bool` indicates whether to manually allocate num_bits.
     num_bits_weight: An `int` number of bits for weight. Default to 8.
     num_bits_activation: An `int` number of bits for activation. Default to 8.
+    quantize_detection_decoder: A `bool` indicates whether to quantize detection
+      decoder. It only works for detection model.
     quantize_detection_head: A `bool` indicates whether to quantize detection
       head. It only works for detection model.
   """
@@ -37,4 +39,5 @@ class Quantization(hyperparams.Config):
   change_num_bits: bool = False
   num_bits_weight: int = 8
   num_bits_activation: int = 8
+  quantize_detection_decoder: bool = False
   quantize_detection_head: bool = False
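
The new flag can also be set directly on the config object. A minimal sketch (not part of this commit; the field values are illustrative) using the Quantization class shown above:

from official.projects.qat.vision.configs import common

# Illustrative values: quantize both the detection decoder (FPN) and the
# detection head, keeping the default 8-bit weight/activation settings.
quantization = common.Quantization(
    quantize_detection_decoder=True,
    quantize_detection_head=True)
print(quantization.quantize_detection_decoder)  # True
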
# --experiment_type=retinanet_spinenet_mobile_coco_qat
# COCO mAP: 18.4 (Evaluated on the TFLite after conversion.)
# QAT only supports float32 tpu due to fake-quant op.
runtime:
distribution_strategy: 'tpu'
mixed_precision_dtype: 'float32'
task:
losses:
l2_weight_decay: 3.0e-05
model:
anchor:
anchor_size: 3
aspect_ratios: [0.5, 1.0, 2.0]
num_scales: 3
backbone:
mobilenet:
model_id: 'MobileNetV2'
filter_size_scale: 1.0
type: 'mobilenet'
decoder:
type: 'fpn'
fpn:
num_filters: 128
use_separable_conv: true
use_keras_layer: true
head:
num_convs: 4
num_filters: 128
use_separable_conv: true
input_size: [256, 256, 3]
max_level: 7
min_level: 3
norm_activation:
activation: 'relu6'
norm_epsilon: 0.001
norm_momentum: 0.99
use_sync_bn: true
train_data:
dtype: 'float32'
global_batch_size: 256
is_training: true
parser:
aug_rand_hflip: true
aug_scale_max: 2.0
aug_scale_min: 0.5
validation_data:
dtype: 'float32'
global_batch_size: 16
is_training: false
quantization:
pretrained_original_checkpoint: 'gs://**/tf2_mobilenetv2_ssd_jul29/28129005/ckpt-277200'
quantize_detection_decoder: true
quantize_detection_head: true
trainer:
optimizer_config:
learning_rate:
stepwise:
boundaries: [263340, 272580]
values: [0.032, 0.0032, 0.00032]
type: 'stepwise'
warmup:
linear:
warmup_learning_rate: 0.00067
warmup_steps: 2000
steps_per_loop: 462
train_steps: 277200
validation_interval: 462
validation_steps: 625
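
A rough sketch of how an override file like the one above is typically consumed in the Model Garden. This assumes the retinanet_spinenet_mobile_coco_qat experiment named in the comment has already been registered via the QAT project's registry imports, and the YAML path below is a placeholder:

from official.core import exp_factory
from official.modeling import hyperparams

# Placeholder path to the YAML override shown above.
config_file = 'path/to/retinanet_mobilenetv2_coco_qat.yaml'

params = exp_factory.get_exp_config('retinanet_spinenet_mobile_coco_qat')
params = hyperparams.override_params_dict(params, config_file, is_strict=True)
assert params.task.quantization.quantize_detection_decoder
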
@@ -21,13 +21,16 @@ import tensorflow_model_optimization as tfmot
 from official.projects.qat.vision.configs import common
 from official.projects.qat.vision.modeling import segmentation_model as qat_segmentation_model
 from official.projects.qat.vision.modeling.heads import dense_prediction_heads as dense_prediction_heads_qat
+from official.projects.qat.vision.modeling.layers import nn_layers as qat_nn_layers
 from official.projects.qat.vision.n_bit import schemes as n_bit_schemes
+from official.projects.qat.vision.quantization import configs as qat_configs
 from official.projects.qat.vision.quantization import helper
 from official.projects.qat.vision.quantization import schemes
 from official.vision import configs
 from official.vision.modeling import classification_model
 from official.vision.modeling import retinanet_model
 from official.vision.modeling.decoders import aspp
+from official.vision.modeling.decoders import fpn
 from official.vision.modeling.heads import dense_prediction_heads
 from official.vision.modeling.heads import segmentation_heads
 from official.vision.modeling.layers import nn_layers
@@ -120,6 +123,16 @@ def build_qat_classification_model(
   return optimized_model
 
 
+def _clone_function_for_fpn(layer):
+  if isinstance(layer, (
+      tf.keras.layers.BatchNormalization,
+      tf.keras.layers.experimental.SyncBatchNormalization)):
+    return tfmot.quantization.keras.quantize_annotate_layer(
+        qat_nn_layers.BatchNormalizationWrapper(layer),
+        qat_configs.Default8BitOutputQuantizeConfig())
+  return layer
+
+
 def build_qat_retinanet(
     model: tf.keras.Model, quantization: common.Quantization,
     model_config: configs.retinanet.RetinaNet) -> tf.keras.Model:
@@ -144,6 +157,7 @@ def build_qat_retinanet(
   scope_dict = {
       'L2': tf.keras.regularizers.l2,
+      'BatchNormalizationWrapper': qat_nn_layers.BatchNormalizationWrapper,
   }
   with tfmot.quantization.keras.quantize_scope(scope_dict):
     annotated_backbone = tfmot.quantization.keras.quantize_annotate_model(
@@ -151,6 +165,17 @@ def build_qat_retinanet(
     optimized_backbone = tfmot.quantization.keras.quantize_apply(
         annotated_backbone,
         scheme=schemes.Default8BitQuantizeScheme())
+
+    decoder = model.decoder
+    if quantization.quantize_detection_decoder:
+      if not isinstance(decoder, fpn.FPN):
+        raise ValueError('Currently only supports FPN.')
+      decoder = tf.keras.models.clone_model(
+          decoder,
+          clone_function=_clone_function_for_fpn,
+      )
+      decoder = tfmot.quantization.keras.quantize_model(decoder)
+
     head = model.head
     if quantization.quantize_detection_head:
       if not isinstance(head, dense_prediction_heads.RetinaNetHead):
@@ -161,7 +186,7 @@ def build_qat_retinanet(
   optimized_model = retinanet_model.RetinaNetModel(
       optimized_backbone,
-      model.decoder,
+      decoder,
       head,
       model.detection_generator,
       min_level=model_config.min_level,
...
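
For readers unfamiliar with the clone_function mechanism used above: tf.keras.models.clone_model rebuilds the functional graph and calls the clone function once per layer, which is how the FPN's normalization layers get wrapped and annotated before quantization. A self-contained toy sketch of the same pattern (the toy model and the stand-in wrapper layer are illustrative, not code from this commit):

import tensorflow as tf

def _toy_clone_fn(layer):
  # Swap every BatchNormalization layer for a stand-in layer; in the commit
  # this is where BatchNormalizationWrapper plus a quantize annotation go.
  if isinstance(layer, tf.keras.layers.BatchNormalization):
    return tf.keras.layers.Activation('linear', name=layer.name + '_wrapped')
  return layer

inputs = tf.keras.Input(shape=(32, 32, 8))
x = tf.keras.layers.Conv2D(8, 3, padding='same')(inputs)
outputs = tf.keras.layers.BatchNormalization()(x)
toy_model = tf.keras.Model(inputs, outputs)

cloned = tf.keras.models.clone_model(toy_model, clone_function=_toy_clone_fn)
# The normalization layer is replaced in the cloned graph; other layers are
# copied unchanged.
assert any(layer.name.endswith('_wrapped') for layer in cloned.layers)
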
@@ -28,6 +28,7 @@ from official.vision.configs import image_classification as classification_cfg
 from official.vision.configs import retinanet as retinanet_cfg
 from official.vision.configs import semantic_segmentation as semantic_segmentation_cfg
 from official.vision.modeling import factory
+from official.vision.modeling.decoders import fpn
 from official.vision.modeling.heads import dense_prediction_heads
@@ -69,61 +70,99 @@ class ClassificationModelBuilderTest(parameterized.TestCase, tf.test.TestCase):
 class RetinaNetBuilderTest(parameterized.TestCase, tf.test.TestCase):
 
   @parameterized.parameters(
-      ('spinenet_mobile', (640, 640), False, False),
-      ('spinenet_mobile', (640, 640), False, True),
+      ('spinenet_mobile', 'identity', (640, 640), False, False),
+      ('spinenet_mobile', 'identity', (640, 640), True, False),
+      ('mobilenet', 'fpn', (640, 640), True, False),
+      ('mobilenet', 'fpn', (640, 640), True, True),
   )
   def test_builder(self,
                    backbone_type,
+                   decoder_type,
                    input_size,
-                   has_attribute_heads,
-                   quantize_detection_head):
+                   quantize_detection_head,
+                   quantize_detection_decoder):
     num_classes = 2
     input_specs = tf.keras.layers.InputSpec(
         shape=[None, input_size[0], input_size[1], 3])
-    if has_attribute_heads:
-      attribute_heads_config = [
-          retinanet_cfg.AttributeHead(name='att1'),
-          retinanet_cfg.AttributeHead(
-              name='att2', type='classification', size=2),
-      ]
-    else:
-      attribute_heads_config = None
+
+    if backbone_type == 'spinenet_mobile':
+      backbone_config = backbones.Backbone(
+          type=backbone_type,
+          spinenet_mobile=backbones.SpineNetMobile(
+              model_id='49',
+              stochastic_depth_drop_rate=0.2,
+              min_level=3,
+              max_level=7,
+              use_keras_upsampling_2d=True))
+    elif backbone_type == 'mobilenet':
+      backbone_config = backbones.Backbone(
+          type=backbone_type,
+          mobilenet=backbones.MobileNet(
+              model_id='MobileNetV2',
+              filter_size_scale=1.0))
+    else:
+      raise ValueError(
+          'backbone_type {} is not supported'.format(backbone_type))
+
+    if decoder_type == 'identity':
+      decoder_config = decoders.Decoder(type=decoder_type)
+    elif decoder_type == 'fpn':
+      decoder_config = decoders.Decoder(
+          type=decoder_type,
+          fpn=decoders.FPN(
+              num_filters=128,
+              use_separable_conv=True,
+              use_keras_layer=True))
+    else:
+      raise ValueError(
+          'decoder_type {} is not supported'.format(decoder_type))
+
     model_config = retinanet_cfg.RetinaNet(
         num_classes=num_classes,
         input_size=[input_size[0], input_size[1], 3],
-        backbone=backbones.Backbone(
-            type=backbone_type,
-            spinenet_mobile=backbones.SpineNetMobile(
-                model_id='49',
-                stochastic_depth_drop_rate=0.2,
-                min_level=3,
-                max_level=7,
-                use_keras_upsampling_2d=True)),
+        backbone=backbone_config,
+        decoder=decoder_config,
         head=retinanet_cfg.RetinaNetHead(
-            attribute_heads=attribute_heads_config,
+            attribute_heads=None,
             use_separable_conv=True))
     l2_regularizer = tf.keras.regularizers.l2(5e-5)
-    quantization_config = common.Quantization(
-        quantize_detection_head=quantize_detection_head)
+
+    # Build the original float32 retinanet model.
     model = factory.build_retinanet(
         input_specs=input_specs,
         model_config=model_config,
         l2_regularizer=l2_regularizer)
+    # Call the model with a dummy input to build the head part.
+    dummy_input = tf.zeros([1] + model_config.input_size)
+    model(dummy_input, training=True)
+
+    # Build the QAT model from the original model with the quantization config.
     qat_model = qat_factory.build_qat_retinanet(
         model=model,
-        quantization=quantization_config,
+        quantization=common.Quantization(
+            quantize_detection_decoder=quantize_detection_decoder,
+            quantize_detection_head=quantize_detection_head),
         model_config=model_config)
-    if has_attribute_heads:
-      self.assertEqual(model_config.head.attribute_heads[0].as_dict(),
-                       dict(name='att1', type='regression', size=1))
-      self.assertEqual(model_config.head.attribute_heads[1].as_dict(),
-                       dict(name='att2', type='classification', size=2))
-    self.assertIsInstance(
-        qat_model.head,
-        (qat_dense_prediction_heads.RetinaNetHeadQuantized
-         if quantize_detection_head else
-         dense_prediction_heads.RetinaNetHead))
+
+    if quantize_detection_head:
+      # The head becomes a RetinaNetHeadQuantized when head quantization is
+      # applied.
+      self.assertIsInstance(qat_model.head,
+                            qat_dense_prediction_heads.RetinaNetHeadQuantized)
+    else:
+      # The head stays a plain RetinaNetHead when head quantization is not
+      # applied.
+      self.assertIsInstance(
+          qat_model.head, dense_prediction_heads.RetinaNetHead)
+      self.assertNotIsInstance(
+          qat_model.head, qat_dense_prediction_heads.RetinaNetHeadQuantized)
+
+    if decoder_type == 'fpn':
+      if quantize_detection_decoder:
+        # The FPN decoder becomes a generic Keras functional model after
+        # quantization is applied.
+        self.assertNotIsInstance(qat_model.decoder, fpn.FPN)
+      else:
+        self.assertIsInstance(qat_model.decoder, fpn.FPN)
 
 
 class SegmentationModelBuilderTest(parameterized.TestCase, tf.test.TestCase):
...
@@ -15,7 +15,7 @@
 """Contains common building blocks for neural networks."""
 
 import enum
-from typing import Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union
+from typing import Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union, Any
 
 import tensorflow as tf
@@ -767,3 +767,23 @@ class ASPPQuantized(aspp.ASPP):
     level = str(self._config_dict['level'])
     backbone_output = inputs[level] if isinstance(inputs, dict) else inputs
     return self.aspp(backbone_output)
+
+
+class BatchNormalizationWrapper(tf.keras.layers.Wrapper):
+  """A wrapper for batch normalization layers that are explicitly not folded.
+
+  It adds an identity depthwise conv right before the normalization, so the
+  normalization op is folded into that identity depthwise conv instead.
+
+  Note that it is only used when regular batch normalization folding does not
+  work. The normalization is then quantized as a 1x1 depthwise conv layer that
+  behaves like the normalization in inference mode (essentially a per-channel
+  multiply and add for the BN).
+  """
+
+  def call(self, inputs: tf.Tensor, *args: Any, **kwargs: Any) -> tf.Tensor:
+    channels = tf.shape(inputs)[-1]
+    x = tf.nn.depthwise_conv2d(
+        inputs, tf.ones([1, 1, channels, 1]), [1, 1, 1, 1], 'VALID')
+    outputs = self.layer.call(x, *args, **kwargs)
+    return outputs
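
A quick standalone check of the claim in the docstring (illustrative only, not part of the commit): an all-ones 1x1 depthwise conv is an identity on the activation values, so routing the input through it before a normalization layer does not change the inference output.

import tensorflow as tf

x = tf.random.uniform([2, 8, 8, 16])
# The same op used in BatchNormalizationWrapper.call: a 1x1 depthwise conv
# with an all-ones kernel, which leaves the values unchanged.
identity = tf.nn.depthwise_conv2d(
    x, tf.ones([1, 1, x.shape[-1], 1]), [1, 1, 1, 1], 'VALID')
tf.debugging.assert_near(x, identity)

bn = tf.keras.layers.BatchNormalization(axis=-1)
# In inference mode the wrapped path (identity conv then BN) matches plain BN.
tf.debugging.assert_near(bn(x, training=False), bn(identity, training=False))
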
@@ -91,6 +91,17 @@ class NNLayersTest(parameterized.TestCase, tf.test.TestCase):
     self.assertAllEqual([2, input_size, input_size, num_filters],
                         feats.shape.as_list())
 
+  @parameterized.parameters(False, True)
+  def test_bnorm_wrapper_creation(self, use_sync_bn):
+    inputs = tf.keras.Input(shape=(64, 64, 128), dtype=tf.float32)
+    if use_sync_bn:
+      norm = tf.keras.layers.experimental.SyncBatchNormalization(axis=-1)
+    else:
+      norm = tf.keras.layers.BatchNormalization(axis=-1)
+    layer = nn_layers.BatchNormalizationWrapper(norm)
+    output = layer(inputs)
+    self.assertAllEqual([None, 64, 64, 128], output.shape)
+
 
 if __name__ == '__main__':
   tf.test.main()
@@ -27,7 +27,9 @@ _QUANTIZATION_WEIGHT_NAMES = [
     'depthwise_kernel_min', 'depthwise_kernel_max',
     'reduce_mean_quantizer_vars_min', 'reduce_mean_quantizer_vars_max',
     'quantize_layer_min', 'quantize_layer_max',
+    'quantize_layer_1_min', 'quantize_layer_1_max',
     'quantize_layer_2_min', 'quantize_layer_2_max',
+    'quantize_layer_3_min', 'quantize_layer_3_max',
     'post_activation_min', 'post_activation_max',
 ]
...