Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
cf2326ca
Commit
cf2326ca
authored
Aug 30, 2021
by
Fan Yang
Committed by
A. Unique TensorFlower
Aug 30, 2021
Browse files
Implement TFLite conversion to convert a SavedModel to TFLite with PTQ.
PiperOrigin-RevId: 393871715
parent
002ec22b
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
288 additions
and
0 deletions
+288
-0
official/vision/beta/serving/export_base.py
official/vision/beta/serving/export_base.py
+11
-0
official/vision/beta/serving/export_tflite.py
official/vision/beta/serving/export_tflite.py
+89
-0
official/vision/beta/serving/export_tflite_lib.py
official/vision/beta/serving/export_tflite_lib.py
+113
-0
official/vision/beta/serving/export_tflite_lib_test.py
official/vision/beta/serving/export_tflite_lib_test.py
+75
-0
No files found.
official/vision/beta/serving/export_base.py
View file @
cf2326ca
...
...
@@ -103,6 +103,10 @@ class ExportModule(export_base.ExportModule, metaclass=abc.ABCMeta):
self
,
inputs
:
tf
.
Tensor
)
->
Mapping
[
str
,
tf
.
Tensor
]:
return
self
.
serve
(
inputs
)
@
tf
.
function
def
inference_for_tflite
(
self
,
inputs
:
tf
.
Tensor
)
->
Mapping
[
str
,
tf
.
Tensor
]:
return
self
.
serve
(
inputs
)
@
tf
.
function
def
inference_from_image_bytes
(
self
,
inputs
:
tf
.
Tensor
):
with
tf
.
device
(
'cpu:0'
):
...
...
@@ -174,6 +178,13 @@ class ExportModule(export_base.ExportModule, metaclass=abc.ABCMeta):
signatures
[
def_name
]
=
self
.
inference_from_tf_example
.
get_concrete_function
(
input_signature
)
elif
key
==
'tflite'
:
input_signature
=
tf
.
TensorSpec
(
shape
=
[
self
.
_batch_size
]
+
self
.
_input_image_size
+
[
self
.
_num_channels
],
dtype
=
tf
.
float32
)
signatures
[
def_name
]
=
self
.
inference_for_tflite
.
get_concrete_function
(
input_signature
)
else
:
raise
ValueError
(
'Unrecognized `input_type`'
)
return
signatures
official/vision/beta/serving/export_tflite.py
0 → 100644
View file @
cf2326ca
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Binary to convert a saved model to tflite model."""
from
absl
import
app
from
absl
import
flags
from
absl
import
logging
import
tensorflow
as
tf
from
official.common
import
registry_imports
# pylint: disable=unused-import
from
official.core
import
exp_factory
from
official.modeling
import
hyperparams
from
official.vision.beta.serving
import
export_tflite_lib
FLAGS
=
flags
.
FLAGS
flags
.
DEFINE_string
(
'experiment'
,
None
,
'experiment type, e.g. retinanet_resnetfpn_coco'
,
required
=
True
)
flags
.
DEFINE_multi_string
(
'config_file'
,
default
=
''
,
help
=
'YAML/JSON files which specifies overrides. The override order '
'follows the order of args. Note that each file '
'can be used as an override template to override the default parameters '
'specified in Python. If the same parameter is specified in both '
'`--config_file` and `--params_override`, `config_file` will be used '
'first, followed by params_override.'
)
flags
.
DEFINE_string
(
'params_override'
,
''
,
'The JSON/YAML file or string which specifies the parameter to be overriden'
' on top of `config_file` template.'
)
flags
.
DEFINE_string
(
'saved_model_dir'
,
None
,
'The directory to the saved model.'
,
required
=
True
)
flags
.
DEFINE_string
(
'tflite_path'
,
None
,
'The path to the output tflite model.'
,
required
=
True
)
flags
.
DEFINE_string
(
'quant_type'
,
default
=
None
,
help
=
'Post training quantization type. Support `int8`, `int8_full`, '
'`fp16`, and `default`. See '
'https://www.tensorflow.org/lite/performance/post_training_quantization '
'for more details.'
)
flags
.
DEFINE_integer
(
'calibration_steps'
,
500
,
'The number of calibration steps for integer model.'
)
def
main
(
_
)
->
None
:
params
=
exp_factory
.
get_exp_config
(
FLAGS
.
experiment
)
if
FLAGS
.
config_file
is
not
None
:
for
config_file
in
FLAGS
.
config_file
:
params
=
hyperparams
.
override_params_dict
(
params
,
config_file
,
is_strict
=
True
)
if
FLAGS
.
params_override
:
params
=
hyperparams
.
override_params_dict
(
params
,
FLAGS
.
params_override
,
is_strict
=
True
)
params
.
validate
()
params
.
lock
()
logging
.
info
(
'Converting SavedModel from %s to TFLite model...'
,
FLAGS
.
saved_model_dir
)
tflite_model
=
export_tflite_lib
.
convert_tflite_model
(
saved_model_dir
=
FLAGS
.
saved_model_dir
,
quant_type
=
FLAGS
.
quant_type
,
params
=
params
,
calibration_steps
=
FLAGS
.
calibration_steps
)
with
tf
.
io
.
gfile
.
GFile
(
FLAGS
.
tflite_path
,
'wb'
)
as
fw
:
fw
.
write
(
tflite_model
)
logging
.
info
(
'TFLite model converted and saved to %s.'
,
FLAGS
.
tflite_path
)
if
__name__
==
'__main__'
:
app
.
run
(
main
)
official/vision/beta/serving/export_tflite_lib.py
0 → 100644
View file @
cf2326ca
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Library to facilitate TFLite model conversion."""
import
functools
from
typing
import
Iterator
,
List
,
Optional
import
tensorflow
as
tf
from
official.core
import
config_definitions
as
cfg
from
official.vision.beta
import
configs
from
official.vision.beta.tasks
import
image_classification
as
img_cls_task
def
create_representative_dataset
(
params
:
cfg
.
ExperimentConfig
)
->
tf
.
data
.
Dataset
:
"""Creates a tf.data.Dataset to load images for representative dataset.
Args:
params: An ExperimentConfig.
Returns:
A tf.data.Dataset instance.
Raises:
ValueError: If task is not supported.
"""
if
isinstance
(
params
.
task
,
configs
.
image_classification
.
ImageClassificationTask
):
task
=
img_cls_task
.
ImageClassificationTask
(
params
.
task
)
else
:
raise
ValueError
(
'Task {} not supported.'
.
format
(
type
(
params
.
task
)))
# Ensure batch size is 1 for TFLite model.
params
.
task
.
train_data
.
global_batch_size
=
1
params
.
task
.
train_data
.
dtype
=
'float32'
return
task
.
build_inputs
(
params
=
params
.
task
.
train_data
)
def
representative_dataset
(
params
:
cfg
.
ExperimentConfig
,
calibration_steps
:
int
=
2000
)
->
Iterator
[
List
[
tf
.
Tensor
]]:
""""Creates representative dataset for input calibration.
Args:
params: An ExperimentConfig.
calibration_steps: The steps to do calibration.
Yields:
An input image tensor.
"""
dataset
=
create_representative_dataset
(
params
=
params
)
for
image
,
_
in
dataset
.
take
(
calibration_steps
):
# Skip images that do not have 3 channels.
if
image
.
shape
[
-
1
]
!=
3
:
continue
yield
[
image
]
def
convert_tflite_model
(
saved_model_dir
:
str
,
quant_type
:
Optional
[
str
]
=
None
,
params
:
Optional
[
cfg
.
ExperimentConfig
]
=
None
,
calibration_steps
:
Optional
[
int
]
=
2000
)
->
bytes
:
"""Converts and returns a TFLite model.
Args:
saved_model_dir: The directory to the SavedModel.
quant_type: The post training quantization (PTQ) method. It can be one of
`default` (dynamic range), `fp16` (float16), `int8` (integer wih float
fallback), `int8_full` (integer only) and None (no quantization).
params: An optional ExperimentConfig to load and preprocess input images to
do calibration for integer quantization.
calibration_steps: The steps to do calibration.
Returns:
A converted TFLite model with optional PTQ.
Raises:
ValueError: If `representative_dataset_path` is not present if integer
quantization is requested.
"""
converter
=
tf
.
lite
.
TFLiteConverter
.
from_saved_model
(
saved_model_dir
)
if
quant_type
:
if
quant_type
.
startswith
(
'int8'
):
converter
.
optimizations
=
[
tf
.
lite
.
Optimize
.
DEFAULT
]
converter
.
representative_dataset
=
functools
.
partial
(
representative_dataset
,
params
=
params
,
calibration_steps
=
calibration_steps
)
if
quant_type
==
'int8_full'
:
converter
.
target_spec
.
supported_ops
=
[
tf
.
lite
.
OpsSet
.
TFLITE_BUILTINS_INT8
]
converter
.
inference_input_type
=
tf
.
uint8
# or tf.int8
converter
.
inference_output_type
=
tf
.
uint8
# or tf.int8
elif
quant_type
==
'fp16'
:
converter
.
optimizations
=
[
tf
.
lite
.
Optimize
.
DEFAULT
]
converter
.
target_spec
.
supported_types
=
[
tf
.
float16
]
elif
quant_type
==
'default'
:
converter
.
optimizations
=
[
tf
.
lite
.
Optimize
.
DEFAULT
]
return
converter
.
convert
()
official/vision/beta/serving/export_tflite_lib_test.py
0 → 100644
View file @
cf2326ca
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for export_tflite_lib."""
import
os
from
absl.testing
import
parameterized
import
tensorflow
as
tf
from
tensorflow.python.distribute
import
combinations
from
official.common
import
registry_imports
# pylint: disable=unused-import
from
official.core
import
exp_factory
from
official.vision.beta.dataloaders
import
tfexample_utils
from
official.vision.beta.serving
import
export_tflite_lib
from
official.vision.beta.serving
import
image_classification
as
image_classification_serving
class
ExportTfliteLibTest
(
tf
.
test
.
TestCase
,
parameterized
.
TestCase
):
def
setUp
(
self
):
super
().
setUp
()
self
.
_test_tfrecord_file
=
os
.
path
.
join
(
self
.
get_temp_dir
(),
'test.tfrecord'
)
self
.
_create_test_tfrecord
(
num_samples
=
50
)
def
_create_test_tfrecord
(
self
,
num_samples
):
tfexample_utils
.
dump_to_tfrecord
(
self
.
_test_tfrecord_file
,
[
tf
.
train
.
Example
.
FromString
(
tfexample_utils
.
create_classification_example
(
image_height
=
256
,
image_width
=
256
))
for
_
in
range
(
num_samples
)
])
def
_export_from_module
(
self
,
module
,
input_type
,
saved_model_dir
):
signatures
=
module
.
get_inference_signatures
(
{
input_type
:
'serving_default'
})
tf
.
saved_model
.
save
(
module
,
saved_model_dir
,
signatures
=
signatures
)
@
combinations
.
generate
(
combinations
.
combine
(
experiment
=
[
'mobilenet_imagenet'
],
quant_type
=
[
None
,
'default'
,
'fp16'
,
'int8'
],
input_image_size
=
[[
224
,
224
]]))
def
test_export_tflite
(
self
,
experiment
,
quant_type
,
input_image_size
):
params
=
exp_factory
.
get_exp_config
(
experiment
)
params
.
task
.
validation_data
.
input_path
=
self
.
_test_tfrecord_file
temp_dir
=
self
.
get_temp_dir
()
module
=
image_classification_serving
.
ClassificationModule
(
params
=
params
,
batch_size
=
1
,
input_image_size
=
input_image_size
)
self
.
_export_from_module
(
module
=
module
,
input_type
=
'tflite'
,
saved_model_dir
=
os
.
path
.
join
(
temp_dir
,
'saved_model'
))
tflite_model
=
export_tflite_lib
.
convert_tflite_model
(
saved_model_dir
=
os
.
path
.
join
(
temp_dir
,
'saved_model'
),
quant_type
=
quant_type
,
params
=
params
,
calibration_steps
=
20
)
self
.
assertIsInstance
(
tflite_model
,
bytes
)
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment