Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
2e948ccb
Commit
2e948ccb
authored
Sep 07, 2022
by
Fan Yang
Committed by
A. Unique TensorFlower
Sep 07, 2022
Browse files
Improve support in TFLite model conversion.
PiperOrigin-RevId: 472790049
parent
e6f465eb
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
76 additions
and
35 deletions
+76
-35
official/vision/serving/export_tflite.py
official/vision/serving/export_tflite.py
+31
-19
official/vision/serving/export_tflite_lib.py
official/vision/serving/export_tflite_lib.py
+45
-16
No files found.
official/vision/serving/export_tflite.py
View file @
2e948ccb
...
...
@@ -44,12 +44,12 @@ from official.vision.serving import export_tflite_lib
FLAGS
=
flags
.
FLAGS
flags
.
DEFINE_string
(
_EXPERIMENT
=
flags
.
DEFINE_string
(
'experiment'
,
None
,
'experiment type, e.g. retinanet_resnetfpn_coco'
,
required
=
True
)
flags
.
DEFINE_multi_string
(
_CONFIG_FILE
=
flags
.
DEFINE_multi_string
(
'config_file'
,
default
=
''
,
help
=
'YAML/JSON files which specifies overrides. The override order '
...
...
@@ -58,15 +58,15 @@ flags.DEFINE_multi_string(
'specified in Python. If the same parameter is specified in both '
'`--config_file` and `--params_override`, `config_file` will be used '
'first, followed by params_override.'
)
flags
.
DEFINE_string
(
_PARAMS_OVERRIDE
=
flags
.
DEFINE_string
(
'params_override'
,
''
,
'The JSON/YAML file or string which specifies the parameter to be overriden'
' on top of `config_file` template.'
)
flags
.
DEFINE_string
(
_SAVED_MODEL_DIR
=
flags
.
DEFINE_string
(
'saved_model_dir'
,
None
,
'The directory to the saved model.'
,
required
=
True
)
flags
.
DEFINE_string
(
_TFLITE_PATH
=
flags
.
DEFINE_string
(
'tflite_path'
,
None
,
'The path to the output tflite model.'
,
required
=
True
)
flags
.
DEFINE_string
(
_QUANT_TYPE
=
flags
.
DEFINE_string
(
'quant_type'
,
default
=
None
,
help
=
'Post training quantization type. Support `int8_fallback`, '
...
...
@@ -74,35 +74,47 @@ flags.DEFINE_string(
'`int8_full_int8_io` and `default`. See '
'https://www.tensorflow.org/lite/performance/post_training_quantization '
'for more details.'
)
flags
.
DEFINE_integer
(
'calibration_steps'
,
500
,
_CALIBRATION_STEPS
=
flags
.
DEFINE_integer
(
'calibration_steps'
,
500
,
'The number of calibration steps for integer model.'
)
_DENYLISTED_OPS
=
flags
.
DEFINE_string
(
'denylisted_ops'
,
''
,
'The comma-separated string of ops '
'that are excluded from integer quantization. The name of '
'ops should be all capital letters, such as CAST or GREATER.'
'This is useful to exclude certains ops that affects quality or latency. '
'Valid ops that should not be included are quantization friendly ops, such '
'as CONV_2D, DEPTHWISE_CONV_2D, FULLY_CONNECTED, etc.'
)
def
main
(
_
)
->
None
:
params
=
exp_factory
.
get_exp_config
(
FLAGS
.
experiment
)
if
FLAGS
.
config_fil
e
is
not
None
:
for
config_file
in
FLAGS
.
config_fil
e
:
params
=
exp_factory
.
get_exp_config
(
_EXPERIMENT
.
value
)
if
_CONFIG_FILE
.
valu
e
is
not
None
:
for
config_file
in
_CONFIG_FILE
.
valu
e
:
params
=
hyperparams
.
override_params_dict
(
params
,
config_file
,
is_strict
=
True
)
if
FLAGS
.
params_overrid
e
:
if
_PARAMS_OVERRIDE
.
valu
e
:
params
=
hyperparams
.
override_params_dict
(
params
,
FLAGS
.
params_overrid
e
,
is_strict
=
True
)
params
,
_PARAMS_OVERRIDE
.
valu
e
,
is_strict
=
True
)
params
.
validate
()
params
.
lock
()
logging
.
info
(
'Converting SavedModel from %s to TFLite model...'
,
FLAGS
.
saved_model_dir
)
_SAVED_MODEL_DIR
.
value
)
if
_DENYLISTED_OPS
.
value
:
denylisted_ops
=
list
(
_DENYLISTED_OPS
.
value
.
split
(
','
))
tflite_model
=
export_tflite_lib
.
convert_tflite_model
(
saved_model_dir
=
FLAGS
.
saved_model_dir
,
quant_type
=
FLAGS
.
quant_typ
e
,
saved_model_dir
=
_SAVED_MODEL_DIR
.
value
,
quant_type
=
_QUANT_TYPE
.
valu
e
,
params
=
params
,
calibration_steps
=
FLAGS
.
calibration_steps
)
calibration_steps
=
_CALIBRATION_STEPS
.
value
,
denylisted_ops
=
denylisted_ops
)
with
tf
.
io
.
gfile
.
GFile
(
FLAGS
.
tflite_path
,
'wb'
)
as
fw
:
with
tf
.
io
.
gfile
.
GFile
(
_TFLITE_PATH
.
value
,
'wb'
)
as
fw
:
fw
.
write
(
tflite_model
)
logging
.
info
(
'TFLite model converted and saved to %s.'
,
FLAGS
.
tflite_path
)
logging
.
info
(
'TFLite model converted and saved to %s.'
,
_TFLITE_PATH
.
value
)
if
__name__
==
'__main__'
:
...
...
official/vision/serving/export_tflite_lib.py
View file @
2e948ccb
...
...
@@ -19,17 +19,21 @@ from typing import Iterator, List, Optional
from
absl
import
logging
import
tensorflow
as
tf
from
official.core
import
base_task
from
official.core
import
config_definitions
as
cfg
from
official.vision
import
configs
from
official.vision
import
tasks
def
create_representative_dataset
(
params
:
cfg
.
ExperimentConfig
)
->
tf
.
data
.
Dataset
:
params
:
cfg
.
ExperimentConfig
,
task
:
Optional
[
base_task
.
Task
]
=
None
)
->
tf
.
data
.
Dataset
:
"""Creates a tf.data.Dataset to load images for representative dataset.
Args:
params: An ExperimentConfig.
task: An optional task instance. If it is None, task will be built according
to the task type in params.
Returns:
A tf.data.Dataset instance.
...
...
@@ -37,6 +41,7 @@ def create_representative_dataset(
Raises:
ValueError: If task is not supported.
"""
if
task
is
None
:
if
isinstance
(
params
.
task
,
configs
.
image_classification
.
ImageClassificationTask
):
...
...
@@ -59,17 +64,20 @@ def create_representative_dataset(
def
representative_dataset
(
params
:
cfg
.
ExperimentConfig
,
task
:
Optional
[
base_task
.
Task
]
=
None
,
calibration_steps
:
int
=
2000
)
->
Iterator
[
List
[
tf
.
Tensor
]]:
""""Creates representative dataset for input calibration.
Args:
params: An ExperimentConfig.
task: An optional task instance. If it is None, task will be built according
to the task type in params.
calibration_steps: The steps to do calibration.
Yields:
An input image tensor.
"""
dataset
=
create_representative_dataset
(
params
=
params
)
dataset
=
create_representative_dataset
(
params
=
params
,
task
=
task
)
for
image
,
_
in
dataset
.
take
(
calibration_steps
):
# Skip images that do not have 3 channels.
if
image
.
shape
[
-
1
]
!=
3
:
...
...
@@ -80,7 +88,9 @@ def representative_dataset(
def
convert_tflite_model
(
saved_model_dir
:
str
,
quant_type
:
Optional
[
str
]
=
None
,
params
:
Optional
[
cfg
.
ExperimentConfig
]
=
None
,
calibration_steps
:
Optional
[
int
]
=
2000
)
->
bytes
:
task
:
Optional
[
base_task
.
Task
]
=
None
,
calibration_steps
:
Optional
[
int
]
=
2000
,
denylisted_ops
:
Optional
[
list
[
str
]]
=
None
)
->
bytes
:
"""Converts and returns a TFLite model.
Args:
...
...
@@ -90,7 +100,11 @@ def convert_tflite_model(saved_model_dir: str,
fallback), `int8_full` (integer only) and None (no quantization).
params: An optional ExperimentConfig to load and preprocess input images to
do calibration for integer quantization.
task: An optional task instance. If it is None, task will be built according
to the task type in params.
calibration_steps: The steps to do calibration.
denylisted_ops: A list of strings containing ops that are excluded from
integer quantization.
Returns:
A converted TFLite model with optional PTQ.
...
...
@@ -106,6 +120,7 @@ def convert_tflite_model(saved_model_dir: str,
converter
.
representative_dataset
=
functools
.
partial
(
representative_dataset
,
params
=
params
,
task
=
task
,
calibration_steps
=
calibration_steps
)
if
quant_type
.
startswith
(
'int8_full'
):
converter
.
target_spec
.
supported_ops
=
[
...
...
@@ -117,6 +132,20 @@ def convert_tflite_model(saved_model_dir: str,
if
quant_type
==
'int8_full_int8_io'
:
converter
.
inference_input_type
=
tf
.
int8
converter
.
inference_output_type
=
tf
.
int8
if
denylisted_ops
:
debug_options
=
tf
.
lite
.
experimental
.
QuantizationDebugOptions
(
denylisted_ops
=
denylisted_ops
)
debugger
=
tf
.
lite
.
experimental
.
QuantizationDebugger
(
converter
=
converter
,
debug_dataset
=
functools
.
partial
(
representative_dataset
,
params
=
params
,
calibration_steps
=
calibration_steps
),
debug_options
=
debug_options
)
debugger
.
run
()
return
debugger
.
get_nondebug_quantized_model
()
elif
quant_type
==
'fp16'
:
converter
.
optimizations
=
[
tf
.
lite
.
Optimize
.
DEFAULT
]
converter
.
target_spec
.
supported_types
=
[
tf
.
float16
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment