"graphbolt/vscode:/vscode.git/clone" did not exist on "107b4347d21cc5fa85eab9263b739e2d4b978748"
Commit 4e1e0a22 authored by Jaehong Kim's avatar Jaehong Kim Committed by A. Unique TensorFlower
Browse files

Add weight copy logic for the head part of the object detection model.

PiperOrigin-RevId: 446826259
parent 7b5f980d
# --experiment_type=retinanet_spinenet_mobile_coco_qat # --experiment_type=retinanet_spinenet_mobile_coco_qat
# COCO mAP: 22.0 # COCO mAP: 23.2
# QAT only supports float32 tpu due to fake-quant op. # QAT only supports float32 tpu due to fake-quant op.
runtime: runtime:
distribution_strategy: 'tpu' distribution_strategy: 'tpu'
......
...@@ -22,6 +22,7 @@ from official.projects.qat.vision.configs import common ...@@ -22,6 +22,7 @@ from official.projects.qat.vision.configs import common
from official.projects.qat.vision.modeling import segmentation_model as qat_segmentation_model from official.projects.qat.vision.modeling import segmentation_model as qat_segmentation_model
from official.projects.qat.vision.modeling.heads import dense_prediction_heads as dense_prediction_heads_qat from official.projects.qat.vision.modeling.heads import dense_prediction_heads as dense_prediction_heads_qat
from official.projects.qat.vision.n_bit import schemes as n_bit_schemes from official.projects.qat.vision.n_bit import schemes as n_bit_schemes
from official.projects.qat.vision.quantization import helper
from official.projects.qat.vision.quantization import schemes from official.projects.qat.vision.quantization import schemes
from official.vision import configs from official.vision import configs
from official.vision.modeling import classification_model from official.vision.modeling import classification_model
...@@ -157,6 +158,7 @@ def build_qat_retinanet( ...@@ -157,6 +158,7 @@ def build_qat_retinanet(
head = ( head = (
dense_prediction_heads_qat.RetinaNetHeadQuantized.from_config( dense_prediction_heads_qat.RetinaNetHeadQuantized.from_config(
head.get_config())) head.get_config()))
optimized_model = retinanet_model.RetinaNetModel( optimized_model = retinanet_model.RetinaNetModel(
optimized_backbone, optimized_backbone,
model.decoder, model.decoder,
...@@ -167,6 +169,12 @@ def build_qat_retinanet( ...@@ -167,6 +169,12 @@ def build_qat_retinanet(
num_scales=model_config.anchor.num_scales, num_scales=model_config.anchor.num_scales,
aspect_ratios=model_config.anchor.aspect_ratios, aspect_ratios=model_config.anchor.aspect_ratios,
anchor_size=model_config.anchor.anchor_size) anchor_size=model_config.anchor.anchor_size)
if quantization.quantize_detection_head:
# Call the model with dummy input to build the head part.
dummy_input = tf.zeros([1] + model_config.input_size)
optimized_model(dummy_input, training=True)
helper.copy_original_weights(model.head, optimized_model.head)
return optimized_model return optimized_model
......
...@@ -21,12 +21,14 @@ import tensorflow as tf ...@@ -21,12 +21,14 @@ import tensorflow as tf
from official.projects.qat.vision.configs import common from official.projects.qat.vision.configs import common
from official.projects.qat.vision.modeling import factory as qat_factory from official.projects.qat.vision.modeling import factory as qat_factory
from official.projects.qat.vision.modeling.heads import dense_prediction_heads as qat_dense_prediction_heads
from official.vision.configs import backbones from official.vision.configs import backbones
from official.vision.configs import decoders from official.vision.configs import decoders
from official.vision.configs import image_classification as classification_cfg from official.vision.configs import image_classification as classification_cfg
from official.vision.configs import retinanet as retinanet_cfg from official.vision.configs import retinanet as retinanet_cfg
from official.vision.configs import semantic_segmentation as semantic_segmentation_cfg from official.vision.configs import semantic_segmentation as semantic_segmentation_cfg
from official.vision.modeling import factory from official.vision.modeling import factory
from official.vision.modeling.heads import dense_prediction_heads
class ClassificationModelBuilderTest(parameterized.TestCase, tf.test.TestCase): class ClassificationModelBuilderTest(parameterized.TestCase, tf.test.TestCase):
...@@ -67,9 +69,14 @@ class ClassificationModelBuilderTest(parameterized.TestCase, tf.test.TestCase): ...@@ -67,9 +69,14 @@ class ClassificationModelBuilderTest(parameterized.TestCase, tf.test.TestCase):
class RetinaNetBuilderTest(parameterized.TestCase, tf.test.TestCase): class RetinaNetBuilderTest(parameterized.TestCase, tf.test.TestCase):
@parameterized.parameters( @parameterized.parameters(
('spinenet_mobile', (640, 640), False), ('spinenet_mobile', (640, 640), False, False),
('spinenet_mobile', (640, 640), False, True),
) )
def test_builder(self, backbone_type, input_size, has_attribute_heads): def test_builder(self,
backbone_type,
input_size,
has_attribute_heads,
quantize_detection_head):
num_classes = 2 num_classes = 2
input_specs = tf.keras.layers.InputSpec( input_specs = tf.keras.layers.InputSpec(
shape=[None, input_size[0], input_size[1], 3]) shape=[None, input_size[0], input_size[1], 3])
...@@ -83,6 +90,7 @@ class RetinaNetBuilderTest(parameterized.TestCase, tf.test.TestCase): ...@@ -83,6 +90,7 @@ class RetinaNetBuilderTest(parameterized.TestCase, tf.test.TestCase):
attribute_heads_config = None attribute_heads_config = None
model_config = retinanet_cfg.RetinaNet( model_config = retinanet_cfg.RetinaNet(
num_classes=num_classes, num_classes=num_classes,
input_size=[input_size[0], input_size[1], 3],
backbone=backbones.Backbone( backbone=backbones.Backbone(
type=backbone_type, type=backbone_type,
spinenet_mobile=backbones.SpineNetMobile( spinenet_mobile=backbones.SpineNetMobile(
...@@ -92,15 +100,17 @@ class RetinaNetBuilderTest(parameterized.TestCase, tf.test.TestCase): ...@@ -92,15 +100,17 @@ class RetinaNetBuilderTest(parameterized.TestCase, tf.test.TestCase):
max_level=7, max_level=7,
use_keras_upsampling_2d=True)), use_keras_upsampling_2d=True)),
head=retinanet_cfg.RetinaNetHead( head=retinanet_cfg.RetinaNetHead(
attribute_heads=attribute_heads_config)) attribute_heads=attribute_heads_config,
use_separable_conv=True))
l2_regularizer = tf.keras.regularizers.l2(5e-5) l2_regularizer = tf.keras.regularizers.l2(5e-5)
quantization_config = common.Quantization() quantization_config = common.Quantization(
quantize_detection_head=quantize_detection_head)
model = factory.build_retinanet( model = factory.build_retinanet(
input_specs=input_specs, input_specs=input_specs,
model_config=model_config, model_config=model_config,
l2_regularizer=l2_regularizer) l2_regularizer=l2_regularizer)
_ = qat_factory.build_qat_retinanet( qat_model = qat_factory.build_qat_retinanet(
model=model, model=model,
quantization=quantization_config, quantization=quantization_config,
model_config=model_config) model_config=model_config)
...@@ -109,6 +119,11 @@ class RetinaNetBuilderTest(parameterized.TestCase, tf.test.TestCase): ...@@ -109,6 +119,11 @@ class RetinaNetBuilderTest(parameterized.TestCase, tf.test.TestCase):
dict(name='att1', type='regression', size=1)) dict(name='att1', type='regression', size=1))
self.assertEqual(model_config.head.attribute_heads[1].as_dict(), self.assertEqual(model_config.head.attribute_heads[1].as_dict(),
dict(name='att2', type='classification', size=2)) dict(name='att2', type='classification', size=2))
self.assertIsInstance(
qat_model.head,
(qat_dense_prediction_heads.RetinaNetHeadQuantized
if quantize_detection_head else
dense_prediction_heads.RetinaNetHead))
class SegmentationModelBuilderTest(parameterized.TestCase, tf.test.TestCase): class SegmentationModelBuilderTest(parameterized.TestCase, tf.test.TestCase):
......
...@@ -21,6 +21,51 @@ import tensorflow_model_optimization as tfmot ...@@ -21,6 +21,51 @@ import tensorflow_model_optimization as tfmot
from official.projects.qat.vision.quantization import configs from official.projects.qat.vision.quantization import configs
# Variable basenames created by tfmot quantization wrappers (fake-quant
# ranges, optimizer step, etc.). They have no counterpart in the float model.
_QUANTIZATION_WEIGHT_NAMES = [
    'output_max', 'output_min', 'optimizer_step', 'kernel_min', 'kernel_max',
    'add_three_min', 'add_three_max', 'divide_six_min', 'divide_six_max',
    'depthwise_kernel_min', 'depthwise_kernel_max',
    'reduce_mean_quantizer_vars_min', 'reduce_mean_quantizer_vars_max',
    'quantize_layer_min', 'quantize_layer_max',
    'quantize_layer_2_min', 'quantize_layer_2_max',
    'post_activation_min', 'post_activation_max',
]

# Variable basenames that exist in the original (float) model and therefore
# must receive copied values in the quantized model.
_ORIGINAL_WEIGHT_NAME = [
    'kernel', 'depthwise_kernel', 'gamma', 'beta', 'moving_mean',
    'moving_variance', 'bias'
]


def is_quantization_weight_name(name: str) -> bool:
  """Returns True if `name` denotes a quantization-only variable.

  Only the basename is inspected: the part after the last '/' with the
  ':<index>' suffix stripped (e.g. 'head/dense/kernel:0' -> 'kernel').

  Args:
    name: Full variable name, e.g. 'layer/kernel_min:0'.

  Returns:
    True for quantization bookkeeping variables, False for original model
    weights.

  Raises:
    ValueError: If the basename is in neither known list.
  """
  simple_name = name.split('/')[-1].split(':')[0]
  if simple_name in _QUANTIZATION_WEIGHT_NAMES:
    return True
  if simple_name in _ORIGINAL_WEIGHT_NAME:
    return False
  raise ValueError('Variable name {} is not supported.'.format(simple_name))


def copy_original_weights(original_model: 'tf.keras.Model',
                          quantized_model: 'tf.keras.Model'):
  """Copies the original model's weights into the quantized model.

  Quantization-only variables (fake-quant min/max, optimizer step) in the
  quantized model are left untouched; every other variable receives the value
  of the corresponding original weight, matched by order of appearance.

  Args:
    original_model: Float model whose weights are the source.
    quantized_model: QAT model whose non-quantization weights are overwritten
      in place via `set_weights`.

  Raises:
    ValueError: If the weight counts of the two models are inconsistent, or a
      variable name is not recognized.
  """
  original_weight_value = original_model.get_weights()
  weight_values = quantized_model.get_weights()

  original_idx = 0
  for idx, weight in enumerate(quantized_model.weights):
    if not is_quantization_weight_name(weight.name):
      if original_idx >= len(original_weight_value):
        # Fixed typo in message: was 'Not enought'.
        raise ValueError('Not enough original model weights.')
      weight_values[idx] = original_weight_value[original_idx]
      original_idx = original_idx + 1

  if original_idx < len(original_weight_value):
    # Fixed typo in message: was 'Not enought'.
    raise ValueError('Not enough quantized model weights.')

  quantized_model.set_weights(weight_values)
class LayerQuantizerHelper(object): class LayerQuantizerHelper(object):
"""Helper class that handles quantizers.""" """Helper class that handles quantizers."""
......
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for helper."""
import numpy as np
import tensorflow as tf
import tensorflow_model_optimization as tfmot
from official.projects.qat.vision.quantization import helper
class HelperTest(tf.test.TestCase):
  """Tests for the weight-copy helpers in `helper`."""

  def create_simple_model(self):
    """Builds a minimal one-layer Dense model used as source and target."""
    dense = tf.keras.layers.Dense(8, input_shape=(16,))
    return tf.keras.models.Sequential([dense])

  def test_copy_original_weights_for_simple_model_with_custom_weights(self):
    # Source model: force every weight to ones so copies are recognizable.
    source_model = self.create_simple_model()
    source_model.set_weights(
        [np.ones_like(w) for w in source_model.get_weights()])

    # Target QAT model: zero everything out first.
    quantized = tfmot.quantization.keras.quantize_model(
        self.create_simple_model())
    quantized.set_weights(
        [np.zeros_like(w) for w in quantized.get_weights()])

    helper.copy_original_weights(source_model, quantized)

    # Every non-quantization weight must now be all ones; quantization
    # bookkeeping variables are skipped.
    copied = 0
    for value, variable in zip(quantized.get_weights(), quantized.weights):
      if helper.is_quantization_weight_name(variable.name):
        continue
      self.assertAllEqual(value, np.ones_like(value))
      copied += 1
    self.assertLen(source_model.weights, copied)
    self.assertGreater(len(quantized.weights), len(source_model.weights))
# Run the test suite when this file is executed directly.
if __name__ == '__main__':
  tf.test.main()
...@@ -21,6 +21,7 @@ import tensorflow_model_optimization as tfmot ...@@ -21,6 +21,7 @@ import tensorflow_model_optimization as tfmot
from official.projects.qat.vision.modeling.layers import nn_blocks as quantized_nn_blocks from official.projects.qat.vision.modeling.layers import nn_blocks as quantized_nn_blocks
from official.projects.qat.vision.modeling.layers import nn_layers as quantized_nn_layers from official.projects.qat.vision.modeling.layers import nn_layers as quantized_nn_layers
from official.projects.qat.vision.quantization import configs from official.projects.qat.vision.quantization import configs
from official.projects.qat.vision.quantization import helper
keras = tf.keras keras = tf.keras
LayerNode = tfmot.quantization.keras.graph_transformations.transforms.LayerNode LayerNode = tfmot.quantization.keras.graph_transformations.transforms.LayerNode
...@@ -31,18 +32,6 @@ _LAYER_NAMES = [ ...@@ -31,18 +32,6 @@ _LAYER_NAMES = [
'Vision>SegmentationHead', 'Vision>SpatialPyramidPooling', 'Vision>ASPP' 'Vision>SegmentationHead', 'Vision>SpatialPyramidPooling', 'Vision>ASPP'
] ]
_QUANTIZATION_WEIGHT_NAMES = [
'output_max', 'output_min', 'optimizer_step', 'kernel_min', 'kernel_max',
'add_three_min', 'add_three_max', 'divide_six_min', 'divide_six_max',
'depthwise_kernel_min', 'depthwise_kernel_max',
'reduce_mean_quantizer_vars_min', 'reduce_mean_quantizer_vars_max'
]
_ORIGINAL_WEIGHT_NAME = [
'kernel', 'depthwise_kernel', 'gamma', 'beta', 'moving_mean',
'moving_variance', 'bias'
]
class CustomLayerQuantize( class CustomLayerQuantize(
tfmot.quantization.keras.graph_transformations.transforms.Transform): tfmot.quantization.keras.graph_transformations.transforms.Transform):
...@@ -58,16 +47,6 @@ class CustomLayerQuantize( ...@@ -58,16 +47,6 @@ class CustomLayerQuantize(
"""See base class.""" """See base class."""
return LayerPattern(self._original_layer_pattern) return LayerPattern(self._original_layer_pattern)
def _is_quantization_weight_name(self, name):
simple_name = name.split('/')[-1].split(':')[0]
if simple_name in _QUANTIZATION_WEIGHT_NAMES:
return True
if simple_name in _ORIGINAL_WEIGHT_NAME:
return False
raise ValueError('Variable name {} is not supported on '
'CustomLayerQuantize({}) transform.'.format(
simple_name, self._original_layer_pattern))
def _create_layer_metadata( def _create_layer_metadata(
self, layer_class_name: str self, layer_class_name: str
) -> Mapping[str, tfmot.quantization.keras.QuantizeConfig]: ) -> Mapping[str, tfmot.quantization.keras.QuantizeConfig]:
...@@ -97,7 +76,7 @@ class CustomLayerQuantize( ...@@ -97,7 +76,7 @@ class CustomLayerQuantize(
match_idx = 0 match_idx = 0
names_and_weights = [] names_and_weights = []
for name_and_weight in quantized_names_and_weights: for name_and_weight in quantized_names_and_weights:
if not self._is_quantization_weight_name(name=name_and_weight[0]): if not helper.is_quantization_weight_name(name=name_and_weight[0]):
name_and_weight = bottleneck_names_and_weights[match_idx] name_and_weight = bottleneck_names_and_weights[match_idx]
match_idx = match_idx + 1 match_idx = match_idx + 1
names_and_weights.append(name_and_weight) names_and_weights.append(name_and_weight)
......
...@@ -28,6 +28,10 @@ class RetinaNetTask(retinanet.RetinaNetTask): ...@@ -28,6 +28,10 @@ class RetinaNetTask(retinanet.RetinaNetTask):
def build_model(self) -> tf.keras.Model: def build_model(self) -> tf.keras.Model:
"""Builds RetinaNet model with QAT.""" """Builds RetinaNet model with QAT."""
model = super(RetinaNetTask, self).build_model() model = super(RetinaNetTask, self).build_model()
# Call the model with dummy input to build the head part.
dummy_input = tf.zeros([1] + self.task_config.model.input_size)
model(dummy_input, training=True)
if self.task_config.quantization: if self.task_config.quantization:
model = factory.build_qat_retinanet( model = factory.build_qat_retinanet(
model, model,
......
...@@ -65,6 +65,7 @@ class RetinaNetTaskTest(parameterized.TestCase, tf.test.TestCase): ...@@ -65,6 +65,7 @@ class RetinaNetTaskTest(parameterized.TestCase, tf.test.TestCase):
task = retinanet.RetinaNetTask(config.task) task = retinanet.RetinaNetTask(config.task)
model = task.build_model() model = task.build_model()
self.assertLen(model.weights, 2393)
metrics = task.build_metrics(training=is_training) metrics = task.build_metrics(training=is_training)
strategy = tf.distribute.get_strategy() strategy = tf.distribute.get_strategy()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment