Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
4e1e0a22
Commit
4e1e0a22
authored
May 05, 2022
by
Jaehong Kim
Committed by
A. Unique TensorFlower
May 05, 2022
Browse files
Add weight copy logic for the head part of the object detection model.
PiperOrigin-RevId: 446826259
parent
7b5f980d
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
135 additions
and
29 deletions
+135
-29
official/projects/qat/vision/configs/experiments/retinanet/coco_spinenet49_mobile_qat_tpu_e2e.yaml
...riments/retinanet/coco_spinenet49_mobile_qat_tpu_e2e.yaml
+1
-1
official/projects/qat/vision/modeling/factory.py
official/projects/qat/vision/modeling/factory.py
+8
-0
official/projects/qat/vision/modeling/factory_test.py
official/projects/qat/vision/modeling/factory_test.py
+20
-5
official/projects/qat/vision/quantization/helper.py
official/projects/qat/vision/quantization/helper.py
+45
-0
official/projects/qat/vision/quantization/helper_test.py
official/projects/qat/vision/quantization/helper_test.py
+54
-0
official/projects/qat/vision/quantization/layer_transforms.py
...cial/projects/qat/vision/quantization/layer_transforms.py
+2
-23
official/projects/qat/vision/tasks/retinanet.py
official/projects/qat/vision/tasks/retinanet.py
+4
-0
official/projects/qat/vision/tasks/retinanet_test.py
official/projects/qat/vision/tasks/retinanet_test.py
+1
-0
No files found.
official/projects/qat/vision/configs/experiments/retinanet/coco_spinenet49_mobile_qat_tpu_e2e.yaml
View file @
4e1e0a22
# --experiment_type=retinanet_spinenet_mobile_coco_qat
# --experiment_type=retinanet_spinenet_mobile_coco_qat
# COCO mAP: 2
2.0
# COCO mAP: 2
3.2
# QAT only supports float32 tpu due to fake-quant op.
# QAT only supports float32 tpu due to fake-quant op.
runtime
:
runtime
:
distribution_strategy
:
'
tpu'
distribution_strategy
:
'
tpu'
...
...
official/projects/qat/vision/modeling/factory.py
View file @
4e1e0a22
...
@@ -22,6 +22,7 @@ from official.projects.qat.vision.configs import common
...
@@ -22,6 +22,7 @@ from official.projects.qat.vision.configs import common
from
official.projects.qat.vision.modeling
import
segmentation_model
as
qat_segmentation_model
from
official.projects.qat.vision.modeling
import
segmentation_model
as
qat_segmentation_model
from
official.projects.qat.vision.modeling.heads
import
dense_prediction_heads
as
dense_prediction_heads_qat
from
official.projects.qat.vision.modeling.heads
import
dense_prediction_heads
as
dense_prediction_heads_qat
from
official.projects.qat.vision.n_bit
import
schemes
as
n_bit_schemes
from
official.projects.qat.vision.n_bit
import
schemes
as
n_bit_schemes
from
official.projects.qat.vision.quantization
import
helper
from
official.projects.qat.vision.quantization
import
schemes
from
official.projects.qat.vision.quantization
import
schemes
from
official.vision
import
configs
from
official.vision
import
configs
from
official.vision.modeling
import
classification_model
from
official.vision.modeling
import
classification_model
...
@@ -157,6 +158,7 @@ def build_qat_retinanet(
...
@@ -157,6 +158,7 @@ def build_qat_retinanet(
head
=
(
head
=
(
dense_prediction_heads_qat
.
RetinaNetHeadQuantized
.
from_config
(
dense_prediction_heads_qat
.
RetinaNetHeadQuantized
.
from_config
(
head
.
get_config
()))
head
.
get_config
()))
optimized_model
=
retinanet_model
.
RetinaNetModel
(
optimized_model
=
retinanet_model
.
RetinaNetModel
(
optimized_backbone
,
optimized_backbone
,
model
.
decoder
,
model
.
decoder
,
...
@@ -167,6 +169,12 @@ def build_qat_retinanet(
...
@@ -167,6 +169,12 @@ def build_qat_retinanet(
num_scales
=
model_config
.
anchor
.
num_scales
,
num_scales
=
model_config
.
anchor
.
num_scales
,
aspect_ratios
=
model_config
.
anchor
.
aspect_ratios
,
aspect_ratios
=
model_config
.
anchor
.
aspect_ratios
,
anchor_size
=
model_config
.
anchor
.
anchor_size
)
anchor_size
=
model_config
.
anchor
.
anchor_size
)
if
quantization
.
quantize_detection_head
:
# Call the model with dummy input to build the head part.
dummpy_input
=
tf
.
zeros
([
1
]
+
model_config
.
input_size
)
optimized_model
(
dummpy_input
,
training
=
True
)
helper
.
copy_original_weights
(
model
.
head
,
optimized_model
.
head
)
return
optimized_model
return
optimized_model
...
...
official/projects/qat/vision/modeling/factory_test.py
View file @
4e1e0a22
...
@@ -21,12 +21,14 @@ import tensorflow as tf
...
@@ -21,12 +21,14 @@ import tensorflow as tf
from
official.projects.qat.vision.configs
import
common
from
official.projects.qat.vision.configs
import
common
from
official.projects.qat.vision.modeling
import
factory
as
qat_factory
from
official.projects.qat.vision.modeling
import
factory
as
qat_factory
from
official.projects.qat.vision.modeling.heads
import
dense_prediction_heads
as
qat_dense_prediction_heads
from
official.vision.configs
import
backbones
from
official.vision.configs
import
backbones
from
official.vision.configs
import
decoders
from
official.vision.configs
import
decoders
from
official.vision.configs
import
image_classification
as
classification_cfg
from
official.vision.configs
import
image_classification
as
classification_cfg
from
official.vision.configs
import
retinanet
as
retinanet_cfg
from
official.vision.configs
import
retinanet
as
retinanet_cfg
from
official.vision.configs
import
semantic_segmentation
as
semantic_segmentation_cfg
from
official.vision.configs
import
semantic_segmentation
as
semantic_segmentation_cfg
from
official.vision.modeling
import
factory
from
official.vision.modeling
import
factory
from
official.vision.modeling.heads
import
dense_prediction_heads
class
ClassificationModelBuilderTest
(
parameterized
.
TestCase
,
tf
.
test
.
TestCase
):
class
ClassificationModelBuilderTest
(
parameterized
.
TestCase
,
tf
.
test
.
TestCase
):
...
@@ -67,9 +69,14 @@ class ClassificationModelBuilderTest(parameterized.TestCase, tf.test.TestCase):
...
@@ -67,9 +69,14 @@ class ClassificationModelBuilderTest(parameterized.TestCase, tf.test.TestCase):
class
RetinaNetBuilderTest
(
parameterized
.
TestCase
,
tf
.
test
.
TestCase
):
class
RetinaNetBuilderTest
(
parameterized
.
TestCase
,
tf
.
test
.
TestCase
):
@
parameterized
.
parameters
(
@
parameterized
.
parameters
(
(
'spinenet_mobile'
,
(
640
,
640
),
False
),
(
'spinenet_mobile'
,
(
640
,
640
),
False
,
False
),
(
'spinenet_mobile'
,
(
640
,
640
),
False
,
True
),
)
)
def
test_builder
(
self
,
backbone_type
,
input_size
,
has_attribute_heads
):
def
test_builder
(
self
,
backbone_type
,
input_size
,
has_attribute_heads
,
quantize_detection_head
):
num_classes
=
2
num_classes
=
2
input_specs
=
tf
.
keras
.
layers
.
InputSpec
(
input_specs
=
tf
.
keras
.
layers
.
InputSpec
(
shape
=
[
None
,
input_size
[
0
],
input_size
[
1
],
3
])
shape
=
[
None
,
input_size
[
0
],
input_size
[
1
],
3
])
...
@@ -83,6 +90,7 @@ class RetinaNetBuilderTest(parameterized.TestCase, tf.test.TestCase):
...
@@ -83,6 +90,7 @@ class RetinaNetBuilderTest(parameterized.TestCase, tf.test.TestCase):
attribute_heads_config
=
None
attribute_heads_config
=
None
model_config
=
retinanet_cfg
.
RetinaNet
(
model_config
=
retinanet_cfg
.
RetinaNet
(
num_classes
=
num_classes
,
num_classes
=
num_classes
,
input_size
=
[
input_size
[
0
],
input_size
[
1
],
3
],
backbone
=
backbones
.
Backbone
(
backbone
=
backbones
.
Backbone
(
type
=
backbone_type
,
type
=
backbone_type
,
spinenet_mobile
=
backbones
.
SpineNetMobile
(
spinenet_mobile
=
backbones
.
SpineNetMobile
(
...
@@ -92,15 +100,17 @@ class RetinaNetBuilderTest(parameterized.TestCase, tf.test.TestCase):
...
@@ -92,15 +100,17 @@ class RetinaNetBuilderTest(parameterized.TestCase, tf.test.TestCase):
max_level
=
7
,
max_level
=
7
,
use_keras_upsampling_2d
=
True
)),
use_keras_upsampling_2d
=
True
)),
head
=
retinanet_cfg
.
RetinaNetHead
(
head
=
retinanet_cfg
.
RetinaNetHead
(
attribute_heads
=
attribute_heads_config
))
attribute_heads
=
attribute_heads_config
,
use_separable_conv
=
True
))
l2_regularizer
=
tf
.
keras
.
regularizers
.
l2
(
5e-5
)
l2_regularizer
=
tf
.
keras
.
regularizers
.
l2
(
5e-5
)
quantization_config
=
common
.
Quantization
()
quantization_config
=
common
.
Quantization
(
quantize_detection_head
=
quantize_detection_head
)
model
=
factory
.
build_retinanet
(
model
=
factory
.
build_retinanet
(
input_specs
=
input_specs
,
input_specs
=
input_specs
,
model_config
=
model_config
,
model_config
=
model_config
,
l2_regularizer
=
l2_regularizer
)
l2_regularizer
=
l2_regularizer
)
_
=
qat_factory
.
build_qat_retinanet
(
qat_model
=
qat_factory
.
build_qat_retinanet
(
model
=
model
,
model
=
model
,
quantization
=
quantization_config
,
quantization
=
quantization_config
,
model_config
=
model_config
)
model_config
=
model_config
)
...
@@ -109,6 +119,11 @@ class RetinaNetBuilderTest(parameterized.TestCase, tf.test.TestCase):
...
@@ -109,6 +119,11 @@ class RetinaNetBuilderTest(parameterized.TestCase, tf.test.TestCase):
dict
(
name
=
'att1'
,
type
=
'regression'
,
size
=
1
))
dict
(
name
=
'att1'
,
type
=
'regression'
,
size
=
1
))
self
.
assertEqual
(
model_config
.
head
.
attribute_heads
[
1
].
as_dict
(),
self
.
assertEqual
(
model_config
.
head
.
attribute_heads
[
1
].
as_dict
(),
dict
(
name
=
'att2'
,
type
=
'classification'
,
size
=
2
))
dict
(
name
=
'att2'
,
type
=
'classification'
,
size
=
2
))
self
.
assertIsInstance
(
qat_model
.
head
,
(
qat_dense_prediction_heads
.
RetinaNetHeadQuantized
if
quantize_detection_head
else
dense_prediction_heads
.
RetinaNetHead
))
class
SegmentationModelBuilderTest
(
parameterized
.
TestCase
,
tf
.
test
.
TestCase
):
class
SegmentationModelBuilderTest
(
parameterized
.
TestCase
,
tf
.
test
.
TestCase
):
...
...
official/projects/qat/vision/quantization/helper.py
View file @
4e1e0a22
...
@@ -21,6 +21,51 @@ import tensorflow_model_optimization as tfmot
...
@@ -21,6 +21,51 @@ import tensorflow_model_optimization as tfmot
from
official.projects.qat.vision.quantization
import
configs
from
official.projects.qat.vision.quantization
import
configs
_QUANTIZATION_WEIGHT_NAMES
=
[
'output_max'
,
'output_min'
,
'optimizer_step'
,
'kernel_min'
,
'kernel_max'
,
'add_three_min'
,
'add_three_max'
,
'divide_six_min'
,
'divide_six_max'
,
'depthwise_kernel_min'
,
'depthwise_kernel_max'
,
'reduce_mean_quantizer_vars_min'
,
'reduce_mean_quantizer_vars_max'
,
'quantize_layer_min'
,
'quantize_layer_max'
,
'quantize_layer_2_min'
,
'quantize_layer_2_max'
,
'post_activation_min'
,
'post_activation_max'
,
]
_ORIGINAL_WEIGHT_NAME
=
[
'kernel'
,
'depthwise_kernel'
,
'gamma'
,
'beta'
,
'moving_mean'
,
'moving_variance'
,
'bias'
]
def
is_quantization_weight_name
(
name
:
str
)
->
bool
:
simple_name
=
name
.
split
(
'/'
)[
-
1
].
split
(
':'
)[
0
]
if
simple_name
in
_QUANTIZATION_WEIGHT_NAMES
:
return
True
if
simple_name
in
_ORIGINAL_WEIGHT_NAME
:
return
False
raise
ValueError
(
'Variable name {} is not supported.'
.
format
(
simple_name
))
def
copy_original_weights
(
original_model
:
tf
.
keras
.
Model
,
quantized_model
:
tf
.
keras
.
Model
):
"""Helper function that copy the original model weights to quantized model."""
original_weight_value
=
original_model
.
get_weights
()
weight_values
=
quantized_model
.
get_weights
()
original_idx
=
0
for
idx
,
weight
in
enumerate
(
quantized_model
.
weights
):
if
not
is_quantization_weight_name
(
weight
.
name
):
if
original_idx
>=
len
(
original_weight_value
):
raise
ValueError
(
'Not enought original model weights.'
)
weight_values
[
idx
]
=
original_weight_value
[
original_idx
]
original_idx
=
original_idx
+
1
if
original_idx
<
len
(
original_weight_value
):
raise
ValueError
(
'Not enought quantized model weights.'
)
quantized_model
.
set_weights
(
weight_values
)
class
LayerQuantizerHelper
(
object
):
class
LayerQuantizerHelper
(
object
):
"""Helper class that handles quantizers."""
"""Helper class that handles quantizers."""
...
...
official/projects/qat/vision/quantization/helper_test.py
0 → 100644
View file @
4e1e0a22
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for helper."""
import
numpy
as
np
import
tensorflow
as
tf
import
tensorflow_model_optimization
as
tfmot
from
official.projects.qat.vision.quantization
import
helper
class
HelperTest
(
tf
.
test
.
TestCase
):
def
create_simple_model
(
self
):
return
tf
.
keras
.
models
.
Sequential
([
tf
.
keras
.
layers
.
Dense
(
8
,
input_shape
=
(
16
,)),
])
def
test_copy_original_weights_for_simple_model_with_custom_weights
(
self
):
one_model
=
self
.
create_simple_model
()
one_weights
=
[
np
.
ones_like
(
weight
)
for
weight
in
one_model
.
get_weights
()]
one_model
.
set_weights
(
one_weights
)
qat_model
=
tfmot
.
quantization
.
keras
.
quantize_model
(
self
.
create_simple_model
())
zero_weights
=
[
np
.
zeros_like
(
weight
)
for
weight
in
qat_model
.
get_weights
()]
qat_model
.
set_weights
(
zero_weights
)
helper
.
copy_original_weights
(
one_model
,
qat_model
)
qat_model_weights
=
qat_model
.
get_weights
()
count
=
0
for
idx
,
weight
in
enumerate
(
qat_model
.
weights
):
if
not
helper
.
is_quantization_weight_name
(
weight
.
name
):
self
.
assertAllEqual
(
qat_model_weights
[
idx
],
np
.
ones_like
(
qat_model_weights
[
idx
]))
count
+=
1
self
.
assertLen
(
one_model
.
weights
,
count
)
self
.
assertGreater
(
len
(
qat_model
.
weights
),
len
(
one_model
.
weights
))
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
official/projects/qat/vision/quantization/layer_transforms.py
View file @
4e1e0a22
...
@@ -21,6 +21,7 @@ import tensorflow_model_optimization as tfmot
...
@@ -21,6 +21,7 @@ import tensorflow_model_optimization as tfmot
from
official.projects.qat.vision.modeling.layers
import
nn_blocks
as
quantized_nn_blocks
from
official.projects.qat.vision.modeling.layers
import
nn_blocks
as
quantized_nn_blocks
from
official.projects.qat.vision.modeling.layers
import
nn_layers
as
quantized_nn_layers
from
official.projects.qat.vision.modeling.layers
import
nn_layers
as
quantized_nn_layers
from
official.projects.qat.vision.quantization
import
configs
from
official.projects.qat.vision.quantization
import
configs
from
official.projects.qat.vision.quantization
import
helper
keras
=
tf
.
keras
keras
=
tf
.
keras
LayerNode
=
tfmot
.
quantization
.
keras
.
graph_transformations
.
transforms
.
LayerNode
LayerNode
=
tfmot
.
quantization
.
keras
.
graph_transformations
.
transforms
.
LayerNode
...
@@ -31,18 +32,6 @@ _LAYER_NAMES = [
...
@@ -31,18 +32,6 @@ _LAYER_NAMES = [
'Vision>SegmentationHead'
,
'Vision>SpatialPyramidPooling'
,
'Vision>ASPP'
'Vision>SegmentationHead'
,
'Vision>SpatialPyramidPooling'
,
'Vision>ASPP'
]
]
_QUANTIZATION_WEIGHT_NAMES
=
[
'output_max'
,
'output_min'
,
'optimizer_step'
,
'kernel_min'
,
'kernel_max'
,
'add_three_min'
,
'add_three_max'
,
'divide_six_min'
,
'divide_six_max'
,
'depthwise_kernel_min'
,
'depthwise_kernel_max'
,
'reduce_mean_quantizer_vars_min'
,
'reduce_mean_quantizer_vars_max'
]
_ORIGINAL_WEIGHT_NAME
=
[
'kernel'
,
'depthwise_kernel'
,
'gamma'
,
'beta'
,
'moving_mean'
,
'moving_variance'
,
'bias'
]
class
CustomLayerQuantize
(
class
CustomLayerQuantize
(
tfmot
.
quantization
.
keras
.
graph_transformations
.
transforms
.
Transform
):
tfmot
.
quantization
.
keras
.
graph_transformations
.
transforms
.
Transform
):
...
@@ -58,16 +47,6 @@ class CustomLayerQuantize(
...
@@ -58,16 +47,6 @@ class CustomLayerQuantize(
"""See base class."""
"""See base class."""
return
LayerPattern
(
self
.
_original_layer_pattern
)
return
LayerPattern
(
self
.
_original_layer_pattern
)
def
_is_quantization_weight_name
(
self
,
name
):
simple_name
=
name
.
split
(
'/'
)[
-
1
].
split
(
':'
)[
0
]
if
simple_name
in
_QUANTIZATION_WEIGHT_NAMES
:
return
True
if
simple_name
in
_ORIGINAL_WEIGHT_NAME
:
return
False
raise
ValueError
(
'Variable name {} is not supported on '
'CustomLayerQuantize({}) transform.'
.
format
(
simple_name
,
self
.
_original_layer_pattern
))
def
_create_layer_metadata
(
def
_create_layer_metadata
(
self
,
layer_class_name
:
str
self
,
layer_class_name
:
str
)
->
Mapping
[
str
,
tfmot
.
quantization
.
keras
.
QuantizeConfig
]:
)
->
Mapping
[
str
,
tfmot
.
quantization
.
keras
.
QuantizeConfig
]:
...
@@ -97,7 +76,7 @@ class CustomLayerQuantize(
...
@@ -97,7 +76,7 @@ class CustomLayerQuantize(
match_idx
=
0
match_idx
=
0
names_and_weights
=
[]
names_and_weights
=
[]
for
name_and_weight
in
quantized_names_and_weights
:
for
name_and_weight
in
quantized_names_and_weights
:
if
not
s
el
f
.
_
is_quantization_weight_name
(
name
=
name_and_weight
[
0
]):
if
not
h
el
per
.
is_quantization_weight_name
(
name
=
name_and_weight
[
0
]):
name_and_weight
=
bottleneck_names_and_weights
[
match_idx
]
name_and_weight
=
bottleneck_names_and_weights
[
match_idx
]
match_idx
=
match_idx
+
1
match_idx
=
match_idx
+
1
names_and_weights
.
append
(
name_and_weight
)
names_and_weights
.
append
(
name_and_weight
)
...
...
official/projects/qat/vision/tasks/retinanet.py
View file @
4e1e0a22
...
@@ -28,6 +28,10 @@ class RetinaNetTask(retinanet.RetinaNetTask):
...
@@ -28,6 +28,10 @@ class RetinaNetTask(retinanet.RetinaNetTask):
def
build_model
(
self
)
->
tf
.
keras
.
Model
:
def
build_model
(
self
)
->
tf
.
keras
.
Model
:
"""Builds RetinaNet model with QAT."""
"""Builds RetinaNet model with QAT."""
model
=
super
(
RetinaNetTask
,
self
).
build_model
()
model
=
super
(
RetinaNetTask
,
self
).
build_model
()
# Call the model with dummy input to build the head part.
dummpy_input
=
tf
.
zeros
([
1
]
+
self
.
task_config
.
model
.
input_size
)
model
(
dummpy_input
,
training
=
True
)
if
self
.
task_config
.
quantization
:
if
self
.
task_config
.
quantization
:
model
=
factory
.
build_qat_retinanet
(
model
=
factory
.
build_qat_retinanet
(
model
,
model
,
...
...
official/projects/qat/vision/tasks/retinanet_test.py
View file @
4e1e0a22
...
@@ -65,6 +65,7 @@ class RetinaNetTaskTest(parameterized.TestCase, tf.test.TestCase):
...
@@ -65,6 +65,7 @@ class RetinaNetTaskTest(parameterized.TestCase, tf.test.TestCase):
task
=
retinanet
.
RetinaNetTask
(
config
.
task
)
task
=
retinanet
.
RetinaNetTask
(
config
.
task
)
model
=
task
.
build_model
()
model
=
task
.
build_model
()
self
.
assertLen
(
model
.
weights
,
2393
)
metrics
=
task
.
build_metrics
(
training
=
is_training
)
metrics
=
task
.
build_metrics
(
training
=
is_training
)
strategy
=
tf
.
distribute
.
get_strategy
()
strategy
=
tf
.
distribute
.
get_strategy
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment