Commit 4e1e0a22 authored by Jaehong Kim's avatar Jaehong Kim Committed by A. Unique TensorFlower
Browse files

Add weight copy logic for the head part of the object detection model.

PiperOrigin-RevId: 446826259
parent 7b5f980d
# --experiment_type=retinanet_spinenet_mobile_coco_qat
# COCO mAP: 22.0
# COCO mAP: 23.2
# QAT only supports float32 tpu due to fake-quant op.
runtime:
distribution_strategy: 'tpu'
......
......@@ -22,6 +22,7 @@ from official.projects.qat.vision.configs import common
from official.projects.qat.vision.modeling import segmentation_model as qat_segmentation_model
from official.projects.qat.vision.modeling.heads import dense_prediction_heads as dense_prediction_heads_qat
from official.projects.qat.vision.n_bit import schemes as n_bit_schemes
from official.projects.qat.vision.quantization import helper
from official.projects.qat.vision.quantization import schemes
from official.vision import configs
from official.vision.modeling import classification_model
......@@ -157,6 +158,7 @@ def build_qat_retinanet(
head = (
dense_prediction_heads_qat.RetinaNetHeadQuantized.from_config(
head.get_config()))
optimized_model = retinanet_model.RetinaNetModel(
optimized_backbone,
model.decoder,
......@@ -167,6 +169,12 @@ def build_qat_retinanet(
num_scales=model_config.anchor.num_scales,
aspect_ratios=model_config.anchor.aspect_ratios,
anchor_size=model_config.anchor.anchor_size)
if quantization.quantize_detection_head:
# Call the model with dummy input to build the head part.
dummpy_input = tf.zeros([1] + model_config.input_size)
optimized_model(dummpy_input, training=True)
helper.copy_original_weights(model.head, optimized_model.head)
return optimized_model
......
......@@ -21,12 +21,14 @@ import tensorflow as tf
from official.projects.qat.vision.configs import common
from official.projects.qat.vision.modeling import factory as qat_factory
from official.projects.qat.vision.modeling.heads import dense_prediction_heads as qat_dense_prediction_heads
from official.vision.configs import backbones
from official.vision.configs import decoders
from official.vision.configs import image_classification as classification_cfg
from official.vision.configs import retinanet as retinanet_cfg
from official.vision.configs import semantic_segmentation as semantic_segmentation_cfg
from official.vision.modeling import factory
from official.vision.modeling.heads import dense_prediction_heads
class ClassificationModelBuilderTest(parameterized.TestCase, tf.test.TestCase):
......@@ -67,9 +69,14 @@ class ClassificationModelBuilderTest(parameterized.TestCase, tf.test.TestCase):
class RetinaNetBuilderTest(parameterized.TestCase, tf.test.TestCase):
@parameterized.parameters(
('spinenet_mobile', (640, 640), False),
('spinenet_mobile', (640, 640), False, False),
('spinenet_mobile', (640, 640), False, True),
)
def test_builder(self, backbone_type, input_size, has_attribute_heads):
def test_builder(self,
backbone_type,
input_size,
has_attribute_heads,
quantize_detection_head):
num_classes = 2
input_specs = tf.keras.layers.InputSpec(
shape=[None, input_size[0], input_size[1], 3])
......@@ -83,6 +90,7 @@ class RetinaNetBuilderTest(parameterized.TestCase, tf.test.TestCase):
attribute_heads_config = None
model_config = retinanet_cfg.RetinaNet(
num_classes=num_classes,
input_size=[input_size[0], input_size[1], 3],
backbone=backbones.Backbone(
type=backbone_type,
spinenet_mobile=backbones.SpineNetMobile(
......@@ -92,15 +100,17 @@ class RetinaNetBuilderTest(parameterized.TestCase, tf.test.TestCase):
max_level=7,
use_keras_upsampling_2d=True)),
head=retinanet_cfg.RetinaNetHead(
attribute_heads=attribute_heads_config))
attribute_heads=attribute_heads_config,
use_separable_conv=True))
l2_regularizer = tf.keras.regularizers.l2(5e-5)
quantization_config = common.Quantization()
quantization_config = common.Quantization(
quantize_detection_head=quantize_detection_head)
model = factory.build_retinanet(
input_specs=input_specs,
model_config=model_config,
l2_regularizer=l2_regularizer)
_ = qat_factory.build_qat_retinanet(
qat_model = qat_factory.build_qat_retinanet(
model=model,
quantization=quantization_config,
model_config=model_config)
......@@ -109,6 +119,11 @@ class RetinaNetBuilderTest(parameterized.TestCase, tf.test.TestCase):
dict(name='att1', type='regression', size=1))
self.assertEqual(model_config.head.attribute_heads[1].as_dict(),
dict(name='att2', type='classification', size=2))
self.assertIsInstance(
qat_model.head,
(qat_dense_prediction_heads.RetinaNetHeadQuantized
if quantize_detection_head else
dense_prediction_heads.RetinaNetHead))
class SegmentationModelBuilderTest(parameterized.TestCase, tf.test.TestCase):
......
......@@ -21,6 +21,51 @@ import tensorflow_model_optimization as tfmot
from official.projects.qat.vision.quantization import configs
_QUANTIZATION_WEIGHT_NAMES = [
    'output_max', 'output_min', 'optimizer_step', 'kernel_min', 'kernel_max',
    'add_three_min', 'add_three_max', 'divide_six_min', 'divide_six_max',
    'depthwise_kernel_min', 'depthwise_kernel_max',
    'reduce_mean_quantizer_vars_min', 'reduce_mean_quantizer_vars_max',
    'quantize_layer_min', 'quantize_layer_max',
    'quantize_layer_2_min', 'quantize_layer_2_max',
    'post_activation_min', 'post_activation_max',
]

_ORIGINAL_WEIGHT_NAME = [
    'kernel', 'depthwise_kernel', 'gamma', 'beta', 'moving_mean',
    'moving_variance', 'bias'
]


def is_quantization_weight_name(name: str) -> bool:
  """Returns whether a variable name belongs to a quantizer-created weight.

  The trailing component of the variable name (after the last '/' and before
  the ':' device suffix) is matched against the known quantization weight
  names and the known original model weight names.

  Args:
    name: Full variable name, e.g. 'head/conv/kernel_min:0'.

  Returns:
    True if the name is a quantization-specific weight, False if it is one of
    the original model weights.

  Raises:
    ValueError: If the name matches neither known list.
  """
  base_name = name.rsplit('/', 1)[-1].partition(':')[0]
  if base_name in _QUANTIZATION_WEIGHT_NAMES:
    return True
  if base_name in _ORIGINAL_WEIGHT_NAME:
    return False
  raise ValueError(f'Variable name {base_name} is not supported.')
def copy_original_weights(original_model: tf.keras.Model,
                          quantized_model: tf.keras.Model):
  """Copies the original model's weights into the quantized model.

  A quantized model interleaves quantizer-created variables (e.g. kernel_min,
  kernel_max) with the original weights, so a positional copy is not possible.
  This walks the quantized model's weights, skips the quantization-specific
  variables, and fills the remaining slots with the original model's weights
  in order.

  Args:
    original_model: Source model providing the trained weight values.
    quantized_model: Destination model; its non-quantization weights are
      overwritten in place via `set_weights`.

  Raises:
    ValueError: If the number of non-quantization weights in the quantized
      model does not match the number of weights in the original model.
  """
  original_weight_value = original_model.get_weights()
  weight_values = quantized_model.get_weights()

  original_idx = 0
  for idx, weight in enumerate(quantized_model.weights):
    if not is_quantization_weight_name(weight.name):
      if original_idx >= len(original_weight_value):
        # More copy targets than sources: the original model ran dry.
        raise ValueError('Not enough original model weights.')
      weight_values[idx] = original_weight_value[original_idx]
      original_idx += 1

  if original_idx < len(original_weight_value):
    # Leftover source weights: the quantized model had too few copy slots.
    raise ValueError('Not enough quantized model weights.')

  quantized_model.set_weights(weight_values)
class LayerQuantizerHelper(object):
"""Helper class that handles quantizers."""
......
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for helper."""
import numpy as np
import tensorflow as tf
import tensorflow_model_optimization as tfmot
from official.projects.qat.vision.quantization import helper
class HelperTest(tf.test.TestCase):
  """Tests for the quantization weight-copy helpers."""

  def create_simple_model(self):
    """Builds a single Dense-layer model used as both copy source and target."""
    return tf.keras.models.Sequential([
        tf.keras.layers.Dense(8, input_shape=(16,)),
    ])

  def test_copy_original_weights_for_simple_model_with_custom_weights(self):
    # Source model: set every weight to ones so copies are recognizable.
    source_model = self.create_simple_model()
    source_model.set_weights(
        [np.ones_like(w) for w in source_model.get_weights()])

    # Quantized target model: zero everything out first so any surviving
    # zeros would reveal a missed copy.
    quantized = tfmot.quantization.keras.quantize_model(
        self.create_simple_model())
    quantized.set_weights(
        [np.zeros_like(w) for w in quantized.get_weights()])

    helper.copy_original_weights(source_model, quantized)

    # Every non-quantization weight must now be all ones; count them so we
    # can check exactly len(source_model.weights) slots were filled.
    copied_count = 0
    quantized_values = quantized.get_weights()
    for position, variable in enumerate(quantized.weights):
      if helper.is_quantization_weight_name(variable.name):
        continue
      self.assertAllEqual(
          quantized_values[position],
          np.ones_like(quantized_values[position]))
      copied_count += 1
    self.assertLen(source_model.weights, copied_count)
    self.assertGreater(len(quantized.weights), len(source_model.weights))
# Run the test suite when this file is executed directly.
if __name__ == '__main__':
  tf.test.main()
......@@ -21,6 +21,7 @@ import tensorflow_model_optimization as tfmot
from official.projects.qat.vision.modeling.layers import nn_blocks as quantized_nn_blocks
from official.projects.qat.vision.modeling.layers import nn_layers as quantized_nn_layers
from official.projects.qat.vision.quantization import configs
from official.projects.qat.vision.quantization import helper
keras = tf.keras
LayerNode = tfmot.quantization.keras.graph_transformations.transforms.LayerNode
......@@ -31,18 +32,6 @@ _LAYER_NAMES = [
'Vision>SegmentationHead', 'Vision>SpatialPyramidPooling', 'Vision>ASPP'
]
_QUANTIZATION_WEIGHT_NAMES = [
'output_max', 'output_min', 'optimizer_step', 'kernel_min', 'kernel_max',
'add_three_min', 'add_three_max', 'divide_six_min', 'divide_six_max',
'depthwise_kernel_min', 'depthwise_kernel_max',
'reduce_mean_quantizer_vars_min', 'reduce_mean_quantizer_vars_max'
]
_ORIGINAL_WEIGHT_NAME = [
'kernel', 'depthwise_kernel', 'gamma', 'beta', 'moving_mean',
'moving_variance', 'bias'
]
class CustomLayerQuantize(
tfmot.quantization.keras.graph_transformations.transforms.Transform):
......@@ -58,16 +47,6 @@ class CustomLayerQuantize(
"""See base class."""
return LayerPattern(self._original_layer_pattern)
def _is_quantization_weight_name(self, name):
simple_name = name.split('/')[-1].split(':')[0]
if simple_name in _QUANTIZATION_WEIGHT_NAMES:
return True
if simple_name in _ORIGINAL_WEIGHT_NAME:
return False
raise ValueError('Variable name {} is not supported on '
'CustomLayerQuantize({}) transform.'.format(
simple_name, self._original_layer_pattern))
def _create_layer_metadata(
self, layer_class_name: str
) -> Mapping[str, tfmot.quantization.keras.QuantizeConfig]:
......@@ -97,7 +76,7 @@ class CustomLayerQuantize(
match_idx = 0
names_and_weights = []
for name_and_weight in quantized_names_and_weights:
if not self._is_quantization_weight_name(name=name_and_weight[0]):
if not helper.is_quantization_weight_name(name=name_and_weight[0]):
name_and_weight = bottleneck_names_and_weights[match_idx]
match_idx = match_idx + 1
names_and_weights.append(name_and_weight)
......
......@@ -28,6 +28,10 @@ class RetinaNetTask(retinanet.RetinaNetTask):
def build_model(self) -> tf.keras.Model:
"""Builds RetinaNet model with QAT."""
model = super(RetinaNetTask, self).build_model()
# Call the model with dummy input to build the head part.
dummpy_input = tf.zeros([1] + self.task_config.model.input_size)
model(dummpy_input, training=True)
if self.task_config.quantization:
model = factory.build_qat_retinanet(
model,
......
......@@ -65,6 +65,7 @@ class RetinaNetTaskTest(parameterized.TestCase, tf.test.TestCase):
task = retinanet.RetinaNetTask(config.task)
model = task.build_model()
self.assertLen(model.weights, 2393)
metrics = task.build_metrics(training=is_training)
strategy = tf.distribute.get_strategy()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment