Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
ce2aea2f
Commit
ce2aea2f
authored
Feb 11, 2022
by
Fan Yang
Committed by
A. Unique TensorFlower
Feb 14, 2022
Browse files
Internal change
PiperOrigin-RevId: 428078415
parent
94696ab1
Changes
53
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
1118 additions
and
1 deletion
+1118
-1
official/projects/qat/vision/quantization/configs.py
official/projects/qat/vision/quantization/configs.py
+337
-0
official/projects/qat/vision/quantization/configs_test.py
official/projects/qat/vision/quantization/configs_test.py
+202
-0
official/projects/qat/vision/quantization/schemes.py
official/projects/qat/vision/quantization/schemes.py
+195
-0
official/projects/qat/vision/registry_imports.py
official/projects/qat/vision/registry_imports.py
+21
-0
official/projects/qat/vision/tasks/__init__.py
official/projects/qat/vision/tasks/__init__.py
+18
-0
official/projects/qat/vision/tasks/image_classification.py
official/projects/qat/vision/tasks/image_classification.py
+50
-0
official/projects/qat/vision/tasks/image_classification_test.py
...al/projects/qat/vision/tasks/image_classification_test.py
+62
-0
official/projects/qat/vision/tasks/retinanet.py
official/projects/qat/vision/tasks/retinanet.py
+36
-0
official/projects/qat/vision/tasks/retinanet_test.py
official/projects/qat/vision/tasks/retinanet_test.py
+67
-0
official/projects/qat/vision/tasks/semantic_segmentation.py
official/projects/qat/vision/tasks/semantic_segmentation.py
+36
-0
official/projects/qat/vision/tasks/semantic_segmentation_test.py
...l/projects/qat/vision/tasks/semantic_segmentation_test.py
+67
-0
official/projects/qat/vision/train.py
official/projects/qat/vision/train.py
+26
-0
official/vision/beta/configs/semantic_segmentation.py
official/vision/beta/configs/semantic_segmentation.py
+1
-1
No files found.
official/projects/qat/vision/quantization/configs.py
0 → 100644
View file @
ce2aea2f
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Default 8-bit QuantizeConfigs."""
from
typing
import
Sequence
,
Callable
,
Tuple
,
Any
,
Dict
import
tensorflow
as
tf
import
tensorflow_model_optimization
as
tfmot
# Short-hand aliases for the tfmot / Keras types used throughout this module.
Quantizer = tfmot.quantization.keras.quantizers.Quantizer
Layer = tf.keras.layers.Layer
# A Keras activation: a callable mapping one tensor to another.
Activation = Callable[[tf.Tensor], tf.Tensor]
# (weight variable, quantizer) pair, as returned by
# `QuantizeConfig.get_weights_and_quantizers`.
WeightAndQuantizer = Tuple[tf.Variable, Quantizer]
# (activation callable, quantizer) pair, as returned by
# `QuantizeConfig.get_activations_and_quantizers`.
ActivationAndQuantizer = Tuple[Activation, Quantizer]
class Default8BitOutputQuantizeConfig(tfmot.quantization.keras.QuantizeConfig):
  """QuantizeConfig which only quantizes the output from a layer.

  Weights and inner activations are left untouched; a single 8-bit
  moving-average quantizer is attached to the layer output.
  """

  def get_weights_and_quantizers(
      self, layer: Layer) -> Sequence[WeightAndQuantizer]:
    # No weights are quantized by this config.
    return []

  def get_activations_and_quantizers(
      self, layer: Layer) -> Sequence[ActivationAndQuantizer]:
    # No activations are quantized by this config.
    return []

  def set_quantize_weights(self, layer: Layer,
                           quantize_weights: Sequence[tf.Tensor]):
    # Nothing to set: no weights were returned above.
    pass

  def set_quantize_activations(self, layer: Layer,
                               quantize_activations: Sequence[Activation]):
    # Nothing to set: no activations were returned above.
    pass

  def get_output_quantizers(self, layer: Layer) -> Sequence[Quantizer]:
    # Asymmetric, per-tensor 8-bit quantization of the layer output.
    output_quantizer = tfmot.quantization.keras.quantizers.MovingAverageQuantizer(
        num_bits=8, per_axis=False, symmetric=False, narrow_range=False)
    return [output_quantizer]

  def get_config(self) -> Dict[str, Any]:
    # Stateless config: nothing to serialize.
    return {}
class NoOpQuantizeConfig(tfmot.quantization.keras.QuantizeConfig):
  """QuantizeConfig which does not quantize any part of the layer.

  Useful for layers that must pass through the quantization machinery
  without having fake-quant ops inserted anywhere.
  """

  def get_weights_and_quantizers(
      self, layer: Layer) -> Sequence[WeightAndQuantizer]:
    # Intentionally empty: no weight is quantized.
    return []

  def get_activations_and_quantizers(
      self, layer: Layer) -> Sequence[ActivationAndQuantizer]:
    # Intentionally empty: no activation is quantized.
    return []

  def set_quantize_weights(self, layer: Layer,
                           quantize_weights: Sequence[tf.Tensor]):
    # No-op by design.
    pass

  def set_quantize_activations(self, layer: Layer,
                               quantize_activations: Sequence[Activation]):
    # No-op by design.
    pass

  def get_output_quantizers(self, layer: Layer) -> Sequence[Quantizer]:
    # The output is left unquantized as well.
    return []

  def get_config(self) -> Dict[str, Any]:
    # Stateless config: nothing to serialize.
    return {}
class Default8BitQuantizeConfig(tfmot.quantization.keras.QuantizeConfig):
  """QuantizeConfig for non recurrent Keras layers.

  Quantizes the weights named in `weight_attrs` with a symmetric, per-tensor
  8-bit `LastValueQuantizer` and the activations named in `activation_attrs`
  with an asymmetric 8-bit `MovingAverageQuantizer`. The layer output is
  optionally quantized as well.
  """

  def __init__(self, weight_attrs: Sequence[str],
               activation_attrs: Sequence[str], quantize_output: bool):
    """Initializes a default 8bit quantize config.

    Args:
      weight_attrs: Attribute names of the layer weights to quantize,
        e.g. `['kernel']`.
      activation_attrs: Attribute names of the layer activations to quantize,
        e.g. `['activation']`.
      quantize_output: Whether to also quantize the layer output.
    """
    self.weight_attrs = weight_attrs
    self.activation_attrs = activation_attrs
    self.quantize_output = quantize_output

    # TODO(pulkitb): For some layers such as Conv2D, per_axis should be True.
    # Add mapping for which layers support per_axis.
    self.weight_quantizer = tfmot.quantization.keras.quantizers.LastValueQuantizer(
        num_bits=8, per_axis=False, symmetric=True, narrow_range=True)
    self.activation_quantizer = (
        tfmot.quantization.keras.quantizers.MovingAverageQuantizer(
            num_bits=8, per_axis=False, symmetric=False, narrow_range=False))

  def get_weights_and_quantizers(
      self, layer: Layer) -> Sequence[WeightAndQuantizer]:
    """See base class."""
    return [(getattr(layer, weight_attr), self.weight_quantizer)
            for weight_attr in self.weight_attrs]

  def get_activations_and_quantizers(
      self, layer: Layer) -> Sequence[ActivationAndQuantizer]:
    """See base class."""
    return [(getattr(layer, activation_attr), self.activation_quantizer)
            for activation_attr in self.activation_attrs]

  def set_quantize_weights(self, layer: Layer,
                           quantize_weights: Sequence[tf.Tensor]):
    """See base class.

    Raises:
      ValueError: If the number of weights does not match `weight_attrs`, or
        if a provided weight has a shape incompatible with the existing one.
    """
    if len(self.weight_attrs) != len(quantize_weights):
      raise ValueError(
          '`set_quantize_weights` called on layer {} with {} '
          'weight parameters, but layer expects {} values.'.format(
              layer.name, len(quantize_weights), len(self.weight_attrs)))

    for weight_attr, weight in zip(self.weight_attrs, quantize_weights):
      current_weight = getattr(layer, weight_attr)
      if current_weight.shape != weight.shape:
        # Fixed message: the two fragments previously concatenated without a
        # separating space ('...incompatible withprovided...').
        raise ValueError('Existing layer weight shape {} is incompatible with '
                         'provided weight shape {}'.format(
                             current_weight.shape, weight.shape))

      setattr(layer, weight_attr, weight)

  def set_quantize_activations(self, layer: Layer,
                               quantize_activations: Sequence[Activation]):
    """See base class.

    Raises:
      ValueError: If the number of activations does not match
        `activation_attrs`.
    """
    if len(self.activation_attrs) != len(quantize_activations):
      raise ValueError(
          '`set_quantize_activations` called on layer {} with {} '
          'activation parameters, but layer expects {} values.'.format(
              layer.name, len(quantize_activations),
              len(self.activation_attrs)))

    for activation_attr, activation in zip(self.activation_attrs,
                                           quantize_activations):
      setattr(layer, activation_attr, activation)

  def get_output_quantizers(self, layer: Layer) -> Sequence[Quantizer]:
    """See base class."""
    if self.quantize_output:
      return [self.activation_quantizer]
    return []

  @classmethod
  def from_config(cls, config: Dict[str, Any]) -> object:
    """Instantiates a `Default8BitQuantizeConfig` from its config.

    Args:
      config: Output of `get_config()`.

    Returns:
      A `Default8BitQuantizeConfig` instance.
    """
    return cls(**config)

  def get_config(self) -> Dict[str, Any]:
    """Get a config for this quantize config."""
    # TODO(pulkitb): Add weight and activation quantizer to config.
    # Currently it's created internally, but ideally the quantizers should be
    # part of the constructor and passed in from the registry.
    return {
        'weight_attrs': self.weight_attrs,
        'activation_attrs': self.activation_attrs,
        'quantize_output': self.quantize_output
    }

  def __eq__(self, other):
    if not isinstance(other, Default8BitQuantizeConfig):
      return False

    # Bug fix: `activation_attrs` was previously compared against itself
    # (`self.activation_attrs == self.activation_attrs`), which made the
    # activation attributes irrelevant to equality. Compare against `other`.
    return (self.weight_attrs == other.weight_attrs and
            self.activation_attrs == other.activation_attrs and
            self.weight_quantizer == other.weight_quantizer and
            self.activation_quantizer == other.activation_quantizer and
            self.quantize_output == other.quantize_output)

  def __ne__(self, other):
    return not self.__eq__(other)
class Default8BitConvWeightsQuantizer(
    tfmot.quantization.keras.quantizers.LastValueQuantizer):
  """Quantizer for handling weights in Conv2D/DepthwiseConv2D layers."""

  def __init__(self):
    """Construct LastValueQuantizer with params specific for TFLite Convs."""
    # Per-channel, symmetric, narrow-range 8-bit quantization for conv kernels.
    super().__init__(
        num_bits=8, per_axis=True, symmetric=True, narrow_range=True)

  def build(self, tensor_shape: tf.TensorShape, name: str, layer: Layer):
    """Build min/max quantization variables."""
    # One min/max pair per output channel (last kernel dimension),
    # initialized to the conventional [-6, 6] range and updated in training
    # outside the gradient path.
    channels = (tensor_shape[-1],)
    min_weight = layer.add_weight(
        name + '_min',
        shape=channels,
        initializer=tf.keras.initializers.Constant(-6.0),
        trainable=False)
    max_weight = layer.add_weight(
        name + '_max',
        shape=channels,
        initializer=tf.keras.initializers.Constant(6.0),
        trainable=False)

    return {'min_var': min_weight, 'max_var': max_weight}
class NoQuantizer(tfmot.quantization.keras.quantizers.Quantizer):
  """Dummy quantizer for explicitly not quantize."""

  def __call__(self, inputs, training, weights, **kwargs):
    # Pass the input through unchanged (identity keeps a graph node).
    return tf.identity(inputs)

  def get_config(self):
    # No state to serialize.
    return {}

  def build(self, tensor_shape, name, layer):
    # No quantization variables are needed.
    return {}
class Default8BitConvQuantizeConfig(Default8BitQuantizeConfig):
  """QuantizeConfig for Conv2D/DepthwiseConv2D layers.

  Identical to `Default8BitQuantizeConfig` except that kernels are quantized
  per output channel via `Default8BitConvWeightsQuantizer`.
  """

  def __init__(self, weight_attrs: Sequence[str],
               activation_attrs: Sequence[str], quantize_output: bool):
    """Initializes default 8bit quantization config for the conv layer."""
    super().__init__(weight_attrs, activation_attrs, quantize_output)
    # Replace the per-tensor weight quantizer with the per-channel conv one.
    self.weight_quantizer = Default8BitConvWeightsQuantizer()
class Default8BitActivationQuantizeConfig(
    tfmot.quantization.keras.QuantizeConfig):
  """QuantizeConfig for keras.layers.Activation.

  `keras.layers.Activation` needs a separate `QuantizeConfig` since the
  decision to quantize depends on the specific activation type.
  """

  def _assert_activation_layer(self, layer: Layer):
    # Guard: this config is only meaningful on an Activation layer.
    if not isinstance(layer, tf.keras.layers.Activation):
      raise RuntimeError(
          'Default8BitActivationQuantizeConfig can only be used with '
          '`keras.layers.Activation`.')

  def get_weights_and_quantizers(
      self, layer: Layer) -> Sequence[WeightAndQuantizer]:
    """See base class."""
    self._assert_activation_layer(layer)
    return []

  def get_activations_and_quantizers(
      self, layer: Layer) -> Sequence[ActivationAndQuantizer]:
    """See base class."""
    self._assert_activation_layer(layer)
    return []

  def set_quantize_weights(self, layer: Layer,
                           quantize_weights: Sequence[tf.Tensor]):
    """See base class."""
    self._assert_activation_layer(layer)

  def set_quantize_activations(self, layer: Layer,
                               quantize_activations: Sequence[Activation]):
    """See base class."""
    self._assert_activation_layer(layer)

  def get_output_quantizers(self, layer: Layer) -> Sequence[Quantizer]:
    """See base class."""
    self._assert_activation_layer(layer)

    # Only named activation functions are supported.
    if not hasattr(layer.activation, '__name__'):
      raise ValueError('Activation {} not supported by '
                       'Default8BitActivationQuantizeConfig.'.format(
                           layer.activation))

    activation_name = layer.activation.__name__
    # This code is copied from TFMOT repo, but added relu6 to support mobilenet.
    if activation_name in ['relu', 'relu6', 'swish', 'hard_swish']:
      # 'relu' should generally get fused into the previous layer.
      return [tfmot.quantization.keras.quantizers.MovingAverageQuantizer(
          num_bits=8, per_axis=False, symmetric=False, narrow_range=False)]
    if activation_name in ['linear', 'softmax', 'sigmoid', 'hard_sigmoid']:
      return []

    raise ValueError('Activation {} not supported by '
                     'Default8BitActivationQuantizeConfig.'.format(
                         layer.activation))

  def get_config(self) -> Dict[str, Any]:
    """Get a config for this quantizer config."""
    return {}
def _types_dict():
  """Returns the custom-object mapping used for Keras (de)serialization."""
  # Keys intentionally match each class' `__name__` so the mapping can be
  # passed straight to `custom_objects` during deserialization.
  return {
      cls.__name__: cls
      for cls in (
          Default8BitOutputQuantizeConfig,
          NoOpQuantizeConfig,
          Default8BitQuantizeConfig,
          Default8BitConvWeightsQuantizer,
          Default8BitConvQuantizeConfig,
          Default8BitActivationQuantizeConfig,
      )
  }
official/projects/qat/vision/quantization/configs_test.py
0 → 100644
View file @
ce2aea2f
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for configs.py."""
# Import libraries
import
numpy
as
np
import
tensorflow
as
tf
import
tensorflow_model_optimization
as
tfmot
from
official.projects.qat.vision.quantization
import
configs
class
_TestHelper
(
object
):
def
_convert_list
(
self
,
list_of_tuples
):
"""Transforms a list of 2-tuples to a tuple of 2 lists.
`QuantizeConfig` methods return a list of 2-tuples in the form
[(weight1, quantizer1), (weight2, quantizer2)]. This function converts
it into a 2-tuple of lists. ([weight1, weight2]), (quantizer1, quantizer2).
Args:
list_of_tuples: List of 2-tuples.
Returns:
2-tuple of lists.
"""
list1
=
[]
list2
=
[]
for
a
,
b
in
list_of_tuples
:
list1
.
append
(
a
)
list2
.
append
(
b
)
return
list1
,
list2
# TODO(pulkitb): Consider asserting on full equality for quantizers.
def
_assert_weight_quantizers
(
self
,
quantizer_list
):
for
quantizer
in
quantizer_list
:
self
.
assertIsInstance
(
quantizer
,
tfmot
.
quantization
.
keras
.
quantizers
.
LastValueQuantizer
)
def
_assert_activation_quantizers
(
self
,
quantizer_list
):
for
quantizer
in
quantizer_list
:
self
.
assertIsInstance
(
quantizer
,
tfmot
.
quantization
.
keras
.
quantizers
.
MovingAverageQuantizer
)
def
_assert_kernel_equality
(
self
,
a
,
b
):
self
.
assertAllEqual
(
a
.
numpy
(),
b
.
numpy
())
class Default8BitQuantizeConfigTest(tf.test.TestCase, _TestHelper):
  """Unit tests for `configs.Default8BitQuantizeConfig`."""

  def _simple_dense_layer(self):
    # A built Dense(2) layer with a concrete (3,) input: gives the tests a
    # real `kernel` variable and an `activation` attribute to work with.
    layer = tf.keras.layers.Dense(2)
    layer.build(input_shape=(3,))
    return layer

  def testGetsQuantizeWeightsAndQuantizers(self):
    layer = self._simple_dense_layer()
    quantize_config = configs.Default8BitQuantizeConfig(
        ['kernel'], ['activation'], False)

    (weights, weight_quantizers) = self._convert_list(
        quantize_config.get_weights_and_quantizers(layer))

    self._assert_weight_quantizers(weight_quantizers)
    self.assertEqual([layer.kernel], weights)

  def testGetsQuantizeActivationsAndQuantizers(self):
    layer = self._simple_dense_layer()
    quantize_config = configs.Default8BitQuantizeConfig(
        ['kernel'], ['activation'], False)

    (activations, activation_quantizers) = self._convert_list(
        quantize_config.get_activations_and_quantizers(layer))

    self._assert_activation_quantizers(activation_quantizers)
    self.assertEqual([layer.activation], activations)

  def testSetsQuantizeWeights(self):
    layer = self._simple_dense_layer()
    # Replacement kernel with the same shape as the layer's kernel.
    quantize_kernel = tf.keras.backend.variable(
        np.ones(layer.kernel.shape.as_list()))
    quantize_config = configs.Default8BitQuantizeConfig(
        ['kernel'], ['activation'], False)

    quantize_config.set_quantize_weights(layer, [quantize_kernel])

    self._assert_kernel_equality(layer.kernel, quantize_kernel)

  def testSetsQuantizeActivations(self):
    layer = self._simple_dense_layer()
    quantize_activation = tf.keras.activations.relu
    quantize_config = configs.Default8BitQuantizeConfig(
        ['kernel'], ['activation'], False)

    quantize_config.set_quantize_activations(layer, [quantize_activation])

    self.assertEqual(layer.activation, quantize_activation)

  def testSetsQuantizeWeights_ErrorOnWrongNumberOfWeights(self):
    layer = self._simple_dense_layer()
    quantize_kernel = tf.keras.backend.variable(
        np.ones(layer.kernel.shape.as_list()))
    quantize_config = configs.Default8BitQuantizeConfig(
        ['kernel'], ['activation'], False)

    # Too few, then too many weights for a single-attr config.
    with self.assertRaises(ValueError):
      quantize_config.set_quantize_weights(layer, [])

    with self.assertRaises(ValueError):
      quantize_config.set_quantize_weights(layer,
                                           [quantize_kernel, quantize_kernel])

  def testSetsQuantizeWeights_ErrorOnWrongShapeOfWeight(self):
    layer = self._simple_dense_layer()
    # Shape [1, 2] deliberately mismatches the Dense(2) kernel shape [3, 2].
    quantize_kernel = tf.keras.backend.variable(np.ones([1, 2]))
    quantize_config = configs.Default8BitQuantizeConfig(
        ['kernel'], ['activation'], False)

    with self.assertRaises(ValueError):
      quantize_config.set_quantize_weights(layer, [quantize_kernel])

  def testSetsQuantizeActivations_ErrorOnWrongNumberOfActivations(self):
    layer = self._simple_dense_layer()
    quantize_activation = tf.keras.activations.relu
    quantize_config = configs.Default8BitQuantizeConfig(
        ['kernel'], ['activation'], False)

    # Too few, then too many activations for a single-attr config.
    with self.assertRaises(ValueError):
      quantize_config.set_quantize_activations(layer, [])

    with self.assertRaises(ValueError):
      quantize_config.set_quantize_activations(
          layer, [quantize_activation, quantize_activation])

  def testGetsResultQuantizers_ReturnsQuantizer(self):
    layer = self._simple_dense_layer()
    # quantize_output=True: exactly one output quantizer is expected.
    quantize_config = configs.Default8BitQuantizeConfig([], [], True)

    output_quantizers = quantize_config.get_output_quantizers(layer)

    self.assertLen(output_quantizers, 1)
    self._assert_activation_quantizers(output_quantizers)

  def testGetsResultQuantizers_EmptyWhenFalse(self):
    layer = self._simple_dense_layer()
    # quantize_output=False: no output quantizer is expected.
    quantize_config = configs.Default8BitQuantizeConfig([], [], False)

    output_quantizers = quantize_config.get_output_quantizers(layer)

    self.assertEqual([], output_quantizers)

  def testSerialization(self):
    # Round-trip the config through Keras (de)serialization and check that
    # both the serialized form and the reconstructed object match.
    quantize_config = configs.Default8BitQuantizeConfig(
        ['kernel'], ['activation'], False)

    expected_config = {
        'class_name': 'Default8BitQuantizeConfig',
        'config': {
            'weight_attrs': ['kernel'],
            'activation_attrs': ['activation'],
            'quantize_output': False
        }
    }
    serialized_quantize_config = tf.keras.utils.serialize_keras_object(
        quantize_config)

    self.assertEqual(expected_config, serialized_quantize_config)

    quantize_config_from_config = tf.keras.utils.deserialize_keras_object(
        serialized_quantize_config,
        module_objects=globals(),
        custom_objects=configs._types_dict())

    self.assertEqual(quantize_config, quantize_config_from_config)
# Standard TF test entry point.
if __name__ == '__main__':
  tf.test.main()
official/projects/qat/vision/quantization/schemes.py
0 → 100644
View file @
ce2aea2f
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Quantization schemes."""
from
typing
import
Type
# Import libraries
import
tensorflow
as
tf
import
tensorflow_model_optimization
as
tfmot
from
official.projects.qat.vision.modeling.layers
import
nn_blocks
as
quantized_nn_blocks
from
official.projects.qat.vision.modeling.layers
import
nn_layers
as
quantized_nn_layers
from
official.projects.qat.vision.quantization
import
configs
# Short-hand aliases for the deeply nested tfmot / Keras names used below.
keras = tf.keras
default_8bit_transforms = (
    tfmot.quantization.keras.default_8bit.default_8bit_transforms)
LayerNode = (
    tfmot.quantization.keras.graph_transformations.transforms.LayerNode)
LayerPattern = (
    tfmot.quantization.keras.graph_transformations.transforms.LayerPattern)
# Variable (suffix) names that belong to the quantization machinery itself,
# i.e. weights created by quantize wrappers rather than by the original layer.
_QUANTIZATION_WEIGHT_NAMES = [
    'output_max', 'output_min', 'optimizer_step',
    'kernel_min', 'kernel_max',
    'depthwise_kernel_min', 'depthwise_kernel_max',
    'reduce_mean_quantizer_vars_min', 'reduce_mean_quantizer_vars_max'
]

# Variable (suffix) names that belong to the original, unquantized layer and
# whose values must be carried over into the quantized replacement.
_ORIGINAL_WEIGHT_NAME = [
    'kernel', 'depthwise_kernel',
    'gamma', 'beta', 'moving_mean', 'moving_variance', 'bias'
]
class CustomLayerQuantize(
    tfmot.quantization.keras.graph_transformations.transforms.Transform):
  """Add QAT support for Keras Custom layer.

  Replaces every layer matching `original_layer_pattern` in the model graph
  with an instance of `quantized_layer_class` built from the same config,
  carrying over the original layer's weights.
  """

  def __init__(self, original_layer_pattern: str,
               quantized_layer_class: Type[keras.layers.Layer]):
    # original_layer_pattern: serialized class name to match,
    #   e.g. 'Vision>BottleneckBlock'.
    # quantized_layer_class: drop-in QAT replacement for the matched layer.
    super(CustomLayerQuantize, self).__init__()
    self._original_layer_pattern = original_layer_pattern
    self._quantized_layer_class = quantized_layer_class

  def pattern(self) -> LayerPattern:
    """See base class."""
    return LayerPattern(self._original_layer_pattern)

  def _is_quantization_weight_name(self, name):
    # Reduce a full variable name like 'block/conv/kernel:0' to 'kernel',
    # then classify it as quantizer-created vs. original-layer weight.
    simple_name = name.split('/')[-1].split(':')[0]
    if simple_name in _QUANTIZATION_WEIGHT_NAMES:
      return True
    if simple_name in _ORIGINAL_WEIGHT_NAME:
      return False
    # Unknown names are an error: silently guessing could mis-align the
    # weight carry-over below.
    raise ValueError('Variable name {} is not supported on '
                     'CustomLayerQuantize({}) transform.'.format(
                         simple_name, self._original_layer_pattern))

  def replacement(self, match_layer: LayerNode) -> LayerNode:
    """See base class."""
    bottleneck_layer = match_layer.layer
    bottleneck_config = bottleneck_layer['config']
    bottleneck_names_and_weights = list(match_layer.names_and_weights)
    # Build the quantized twin from the original layer's config.
    quantized_layer = self._quantized_layer_class(**bottleneck_config)
    dummy_input_shape = [1, 64, 128, 1]
    # SegmentationHead layer requires a tuple of 2 tensors.
    if isinstance(quantized_layer,
                  quantized_nn_layers.SegmentationHeadQuantized):
      dummy_input_shape = ([1, 1, 1, 1], [1, 1, 1, 1])
    # Trigger variable creation so `weights`/`get_weights()` are populated.
    quantized_layer.compute_output_shape(dummy_input_shape)
    quantized_names_and_weights = zip(
        [weight.name for weight in quantized_layer.weights],
        quantized_layer.get_weights())
    # Walk the quantized layer's weights in order: quantizer-created
    # variables keep their fresh values; original-layer weights are replaced
    # by the matched layer's weights, consumed in order via `match_idx`.
    # NOTE(review): this assumes both layers create their original weights
    # in the same order — confirm when adding new quantized blocks.
    match_idx = 0
    names_and_weights = []
    for name_and_weight in quantized_names_and_weights:
      if not self._is_quantization_weight_name(name=name_and_weight[0]):
        name_and_weight = bottleneck_names_and_weights[match_idx]
        match_idx = match_idx + 1
      names_and_weights.append(name_and_weight)

    # Every original weight must have been consumed exactly once.
    if match_idx != len(bottleneck_names_and_weights):
      raise ValueError('{}/{} of Bottleneck weights is transformed.'.format(
          match_idx, len(bottleneck_names_and_weights)))
    quantized_layer_config = keras.layers.serialize(quantized_layer)
    quantized_layer_config['name'] = quantized_layer_config['config']['name']

    # Blocks that handle quantization internally get a no-op config; all
    # other replacements only need their output quantized.
    if bottleneck_layer['class_name'] in [
        'Vision>Conv2DBNBlock',
        'Vision>InvertedBottleneckBlock',
        'Vision>SegmentationHead',
        'Vision>SpatialPyramidPooling',
        'Vision>ASPP',
        # TODO(yeqing): Removes the Beta layers.
        'Beta>Conv2DBNBlock',
        'Beta>InvertedBottleneckBlock',
        'Beta>SegmentationHead',
        'Beta>SpatialPyramidPooling',
        'Beta>ASPP'
    ]:
      layer_metadata = {'quantize_config': configs.NoOpQuantizeConfig()}
    else:
      layer_metadata = {
          'quantize_config': configs.Default8BitOutputQuantizeConfig()
      }

    return LayerNode(
        quantized_layer_config,
        metadata=layer_metadata,
        names_and_weights=names_and_weights)
class QuantizeLayoutTransform(
    tfmot.quantization.keras.QuantizeLayoutTransform):
  """Default model transformations."""

  def apply(self, model, layer_quantize_map):
    """Implement default 8-bit transforms.

    Currently this means the following.
      1. Pull activations into layers, and apply fuse activations. (TODO)
      2. Modify range in incoming layers for Concat. (TODO)
      3. Fuse Conv2D/DepthwiseConv2D + BN into single layer.

    Args:
      model: Keras model to be quantized.
      layer_quantize_map: Map with keys as layer names, and values as dicts
        containing custom `QuantizeConfig`s which may have been passed with
        layers.

    Returns:
      (Transformed Keras model to better match TensorFlow Lite backend,
      updated layer quantize map.)
    """
    # Order matters: more specific fusion patterns are listed before the
    # generic ones, and the custom-layer transforms come last.
    transforms = [
        default_8bit_transforms.InputLayerQuantize(),
        default_8bit_transforms.SeparableConv1DQuantize(),
        default_8bit_transforms.SeparableConvQuantize(),
        default_8bit_transforms.Conv2DReshapeBatchNormReLUQuantize(),
        default_8bit_transforms.Conv2DReshapeBatchNormActivationQuantize(),
        default_8bit_transforms.Conv2DBatchNormReLUQuantize(),
        default_8bit_transforms.Conv2DBatchNormActivationQuantize(),
        default_8bit_transforms.Conv2DReshapeBatchNormQuantize(),
        default_8bit_transforms.Conv2DBatchNormQuantize(),
        default_8bit_transforms.ConcatTransform6Inputs(),
        default_8bit_transforms.ConcatTransform5Inputs(),
        default_8bit_transforms.ConcatTransform4Inputs(),
        default_8bit_transforms.ConcatTransform3Inputs(),
        default_8bit_transforms.ConcatTransform(),
        default_8bit_transforms.LayerReLUQuantize(),
        default_8bit_transforms.LayerReluActivationQuantize(),
        CustomLayerQuantize('Vision>BottleneckBlock',
                            quantized_nn_blocks.BottleneckBlockQuantized),
        CustomLayerQuantize(
            'Vision>InvertedBottleneckBlock',
            quantized_nn_blocks.InvertedBottleneckBlockQuantized),
        CustomLayerQuantize('Vision>Conv2DBNBlock',
                            quantized_nn_blocks.Conv2DBNBlockQuantized),
        CustomLayerQuantize('Vision>SegmentationHead',
                            quantized_nn_layers.SegmentationHeadQuantized),
        CustomLayerQuantize('Vision>SpatialPyramidPooling',
                            quantized_nn_layers.SpatialPyramidPoolingQuantized),
        CustomLayerQuantize('Vision>ASPP',
                            quantized_nn_layers.ASPPQuantized),
        # TODO(yeqing): Remove the `Beta` components.
        CustomLayerQuantize('Beta>BottleneckBlock',
                            quantized_nn_blocks.BottleneckBlockQuantized),
        CustomLayerQuantize(
            'Beta>InvertedBottleneckBlock',
            quantized_nn_blocks.InvertedBottleneckBlockQuantized),
        CustomLayerQuantize('Beta>Conv2DBNBlock',
                            quantized_nn_blocks.Conv2DBNBlockQuantized),
        CustomLayerQuantize('Beta>SegmentationHead',
                            quantized_nn_layers.SegmentationHeadQuantized),
        CustomLayerQuantize('Beta>SpatialPyramidPooling',
                            quantized_nn_layers.SpatialPyramidPoolingQuantized),
        CustomLayerQuantize('Beta>ASPP',
                            quantized_nn_layers.ASPPQuantized)
    ]
    return tfmot.quantization.keras.graph_transformations.model_transformer.ModelTransformer(
        model, transforms,
        set(layer_quantize_map.keys()), layer_quantize_map).transform()
class Default8BitQuantizeScheme(
    tfmot.quantization.keras.default_8bit.Default8BitQuantizeScheme):
  """Default 8-bit scheme using this project's layout transformer."""

  def get_layout_transformer(self):
    # Swap in the transformer that knows about the custom Vision/Beta blocks.
    return QuantizeLayoutTransform()
official/projects/qat/vision/registry_imports.py
0 → 100644
View file @
ce2aea2f
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""All necessary imports for registration on qat project."""
# pylint: disable=unused-import
from
official.projects.qat.vision
import
configs
from
official.projects.qat.vision.modeling
import
layers
from
official.projects.qat.vision.tasks
import
image_classification
from
official.projects.qat.vision.tasks
import
retinanet
from
official.projects.qat.vision.tasks
import
semantic_segmentation
official/projects/qat/vision/tasks/__init__.py
0 → 100644
View file @
ce2aea2f
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Tasks package definition."""
from
official.projects.qat.vision.tasks
import
image_classification
official/projects/qat/vision/tasks/image_classification.py
0 → 100644
View file @
ce2aea2f
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Image classification task definition."""
import
tensorflow
as
tf
from
official.core
import
task_factory
from
official.projects.qat.vision.configs
import
image_classification
as
exp_cfg
from
official.projects.qat.vision.modeling
import
factory
from
official.vision.beta.tasks
import
image_classification
@task_factory.register_task_cls(exp_cfg.ImageClassificationTask)
class ImageClassificationTask(image_classification.ImageClassificationTask):
  """A task for image classification with QAT."""

  def build_model(self) -> tf.keras.Model:
    """Builds classification model with QAT."""
    task_config = self.task_config
    input_specs = tf.keras.layers.InputSpec(
        shape=[None] + task_config.model.input_size)

    # Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
    # (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
    # (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
    l2_weight_decay = task_config.losses.l2_weight_decay
    if l2_weight_decay:
      l2_regularizer = tf.keras.regularizers.l2(l2_weight_decay / 2.0)
    else:
      l2_regularizer = None

    # Build the float model first, then wrap it for QAT when configured.
    model = super().build_model()
    if task_config.quantization:
      model = factory.build_qat_classification_model(
          model,
          task_config.quantization,
          input_specs=input_specs,
          model_config=task_config.model,
          l2_regularizer=l2_regularizer)

    return model
official/projects/qat/vision/tasks/image_classification_test.py
0 → 100644
View file @
ce2aea2f
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for image classification task."""
# pylint: disable=unused-import
from
absl.testing
import
parameterized
import
orbit
import
tensorflow
as
tf
from
official.core
import
exp_factory
from
official.modeling
import
optimization
from
official.projects.qat.vision.tasks
import
image_classification
as
img_cls_task
from
official.vision
import
beta
class ImageClassificationTaskTest(tf.test.TestCase, parameterized.TestCase):
  """Smoke test: one train and one validation step for each QAT config."""

  @parameterized.parameters(('resnet_imagenet_qat'),
                            ('mobilenet_imagenet_qat'))
  def test_task(self, config_name):
    config = exp_factory.get_exp_config(config_name)
    # Tiny batch keeps the smoke test fast.
    config.task.train_data.global_batch_size = 2

    task = img_cls_task.ImageClassificationTask(config.task)
    model = task.build_model()
    metrics = task.build_metrics()
    strategy = tf.distribute.get_strategy()

    dataset = orbit.utils.make_distributed_dataset(strategy, task.build_inputs,
                                                   config.task.train_data)
    iterator = iter(dataset)
    opt_factory = optimization.OptimizerFactory(config.trainer.optimizer_config)
    optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())

    # One train step; the expected metrics must appear in the logs.
    logs = task.train_step(next(iterator), model, optimizer, metrics=metrics)
    for metric in metrics:
      logs[metric.name] = metric.result()
    self.assertIn('loss', logs)
    self.assertIn('accuracy', logs)
    self.assertIn('top_5_accuracy', logs)

    # One validation step with the same expectations.
    logs = task.validation_step(next(iterator), model, metrics=metrics)
    for metric in metrics:
      logs[metric.name] = metric.result()
    self.assertIn('loss', logs)
    self.assertIn('accuracy', logs)
    self.assertIn('top_5_accuracy', logs)
# Run the tests in this module when executed as a script.
if __name__ == '__main__':
  tf.test.main()
official/projects/qat/vision/tasks/retinanet.py
0 → 100644
View file @
ce2aea2f
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""RetinaNet task definition."""
import
tensorflow
as
tf
from
official.core
import
task_factory
from
official.projects.qat.vision.configs
import
retinanet
as
exp_cfg
from
official.projects.qat.vision.modeling
import
factory
from
official.vision.beta.tasks
import
retinanet
@task_factory.register_task_cls(exp_cfg.RetinaNetTask)
class RetinaNetTask(retinanet.RetinaNetTask):
  """RetinaNet object detection task with quantization-aware training."""

  def build_model(self) -> tf.keras.Model:
    """Builds the RetinaNet model, applying QAT when configured.

    Returns:
      A `tf.keras.Model`: the QAT-wrapped model if
      `task_config.quantization` is set, otherwise the plain float model.
    """
    float_model = super().build_model()
    if not self.task_config.quantization:
      return float_model
    # Rewrite the float model through the QAT factory.
    return factory.build_qat_retinanet(
        float_model,
        self.task_config.quantization,
        model_config=self.task_config.model)
official/projects/qat/vision/tasks/retinanet_test.py
0 → 100644
View file @
ce2aea2f
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Tests for RetinaNet task."""
# pylint: disable=unused-import
from
absl.testing
import
parameterized
import
orbit
import
tensorflow
as
tf
from
official.core
import
exp_factory
from
official.modeling
import
optimization
from
official.projects.qat.vision.tasks
import
retinanet
from
official.vision
import
beta
from
official.vision.beta.configs
import
retinanet
as
exp_cfg
class RetinaNetTaskTest(parameterized.TestCase, tf.test.TestCase):
  """Smoke tests for the QAT RetinaNet task."""

  @parameterized.parameters(
      ('retinanet_spinenet_mobile_coco_qat', True),
      ('retinanet_spinenet_mobile_coco_qat', False),
  )
  def test_retinanet_task(self, test_config, is_training):
    """RetinaNet task test for training and val using toy configs."""
    config = exp_factory.get_exp_config(test_config)
    # Shrink the experiment so that a single step runs quickly locally.
    config.task.model.input_size = [128, 128, 3]
    config.trainer.steps_per_loop = 1
    config.train_steps = 1
    for split in (config.task.train_data, config.task.validation_data):
      split.global_batch_size = 1
      split.shuffle_buffer_size = 2

    task = retinanet.RetinaNetTask(config.task)
    model = task.build_model()
    metrics = task.build_metrics(training=is_training)
    strategy = tf.distribute.get_strategy()

    split_config = (
        config.task.train_data if is_training else config.task.validation_data)
    dataset = orbit.utils.make_distributed_dataset(
        strategy, task.build_inputs, split_config)
    batch = next(iter(dataset))

    opt_factory = optimization.OptimizerFactory(config.trainer.optimizer_config)
    optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())

    # Run exactly one step of the requested phase.
    if is_training:
      task.train_step(batch, model, optimizer, metrics=metrics)
    else:
      task.validation_step(batch, model, metrics=metrics)
# Run the tests in this module when executed as a script.
if __name__ == '__main__':
  tf.test.main()
official/projects/qat/vision/tasks/semantic_segmentation.py
0 → 100644
View file @
ce2aea2f
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Semantic segmentation task definition."""
import
tensorflow
as
tf
from
official.core
import
task_factory
from
official.projects.qat.vision.configs
import
semantic_segmentation
as
exp_cfg
from
official.projects.qat.vision.modeling
import
factory
from
official.vision.beta.tasks
import
semantic_segmentation
@task_factory.register_task_cls(exp_cfg.SemanticSegmentationTask)
class SemanticSegmentationTask(semantic_segmentation.SemanticSegmentationTask):
  """Semantic segmentation task with quantization-aware training."""

  def build_model(self) -> tf.keras.Model:
    """Builds the segmentation model, applying QAT when configured.

    Returns:
      A `tf.keras.Model`: the QAT-wrapped model if
      `task_config.quantization` is set, otherwise the plain float model.
    """
    # Batch dimension is left unknown; spatial/channel dims come from config.
    input_specs = tf.keras.layers.InputSpec(
        shape=[None] + self.task_config.model.input_size)
    float_model = super().build_model()
    if not self.task_config.quantization:
      return float_model
    # Rewrite the float model through the QAT factory.
    return factory.build_qat_segmentation_model(
        float_model, self.task_config.quantization, input_specs)
official/projects/qat/vision/tasks/semantic_segmentation_test.py
0 → 100644
View file @
ce2aea2f
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for semantic segmentation task."""
# pylint: disable=unused-import
from
absl.testing
import
parameterized
import
orbit
import
tensorflow
as
tf
from
official.core
import
exp_factory
from
official.modeling
import
optimization
from
official.projects.qat.vision.tasks
import
semantic_segmentation
from
official.vision
import
beta
from
official.vision.beta.configs
import
semantic_segmentation
as
exp_cfg
class SemanticSegmentationTaskTest(parameterized.TestCase, tf.test.TestCase):
  """Smoke tests for the QAT semantic segmentation task."""

  @parameterized.parameters(
      ('mnv2_deeplabv3_pascal_qat', True),
      ('mnv2_deeplabv3_pascal_qat', False),
  )
  def test_semantic_segmentation_task(self, test_config, is_training):
    """Semantic segmentation task test for training and val using toy configs."""
    config = exp_factory.get_exp_config(test_config)
    # Shrink the experiment so that a single step runs quickly locally.
    config.task.model.input_size = [512, 512, 3]
    config.trainer.steps_per_loop = 1
    config.train_steps = 1
    for split in (config.task.train_data, config.task.validation_data):
      split.global_batch_size = 1
      split.shuffle_buffer_size = 2
    config.task.model.decoder.aspp.output_tensor = True

    task = semantic_segmentation.SemanticSegmentationTask(config.task)
    model = task.build_model()
    metrics = task.build_metrics(training=is_training)
    strategy = tf.distribute.get_strategy()

    split_config = (
        config.task.train_data if is_training else config.task.validation_data)
    dataset = orbit.utils.make_distributed_dataset(
        strategy, task.build_inputs, split_config)
    batch = next(iter(dataset))

    opt_factory = optimization.OptimizerFactory(config.trainer.optimizer_config)
    optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())

    # Run exactly one step of the requested phase.
    if is_training:
      task.train_step(batch, model, optimizer, metrics=metrics)
    else:
      task.validation_step(batch, model, metrics=metrics)
# Run the tests in this module when executed as a script.
if __name__ == '__main__':
  tf.test.main()
official/projects/qat/vision/train.py
0 → 100644
View file @
ce2aea2f
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TensorFlow Model Garden Vision training driver, including QAT configs.."""
from
absl
import
app
from
official.common
import
flags
as
tfm_flags
from
official.projects.qat.vision
import
registry_imports
# pylint: disable=unused-import
from
official.vision.beta
import
train
if __name__ == '__main__':
  # Register the standard Model Garden flags, then hand control to the shared
  # vision training driver. The QAT tasks/configs are made discoverable via
  # the `registry_imports` module imported above.
  tfm_flags.define_flags()
  app.run(train.main)
official/vision/beta/configs/semantic_segmentation.py
View file @
ce2aea2f
...
...
@@ -148,7 +148,7 @@ def semantic_segmentation() -> cfg.ExperimentConfig:
# PASCAL VOC 2012 Dataset
PASCAL_TRAIN_EXAMPLES
=
10582
PASCAL_VAL_EXAMPLES
=
1449
PASCAL_INPUT_PATH_BASE
=
'pascal_voc_seg'
PASCAL_INPUT_PATH_BASE
=
'
gs://**/
pascal_voc_seg'
@
exp_factory
.
register_config_factory
(
'seg_deeplabv3_pascal'
)
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment