Commit cf2326ca authored by Fan Yang's avatar Fan Yang Committed by A. Unique TensorFlower
Browse files

Implement TFLite conversion to convert a SavedModel to TFLite with PTQ.

PiperOrigin-RevId: 393871715
parent 002ec22b
......@@ -103,6 +103,10 @@ class ExportModule(export_base.ExportModule, metaclass=abc.ABCMeta):
self, inputs: tf.Tensor) -> Mapping[str, tf.Tensor]:
return self.serve(inputs)
@tf.function
def inference_for_tflite(self, inputs: tf.Tensor) -> Mapping[str, tf.Tensor]:
return self.serve(inputs)
@tf.function
def inference_from_image_bytes(self, inputs: tf.Tensor):
with tf.device('cpu:0'):
......@@ -174,6 +178,13 @@ class ExportModule(export_base.ExportModule, metaclass=abc.ABCMeta):
signatures[
def_name] = self.inference_from_tf_example.get_concrete_function(
input_signature)
elif key == 'tflite':
input_signature = tf.TensorSpec(
shape=[self._batch_size] + self._input_image_size +
[self._num_channels],
dtype=tf.float32)
signatures[def_name] = self.inference_for_tflite.get_concrete_function(
input_signature)
else:
raise ValueError('Unrecognized `input_type`')
return signatures
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Binary to convert a saved model to tflite model."""
from absl import app
from absl import flags
from absl import logging
import tensorflow as tf
from official.common import registry_imports # pylint: disable=unused-import
from official.core import exp_factory
from official.modeling import hyperparams
from official.vision.beta.serving import export_tflite_lib
FLAGS = flags.FLAGS
flags.DEFINE_string(
'experiment',
None,
'experiment type, e.g. retinanet_resnetfpn_coco',
required=True)
flags.DEFINE_multi_string(
'config_file',
default='',
help='YAML/JSON files which specifies overrides. The override order '
'follows the order of args. Note that each file '
'can be used as an override template to override the default parameters '
'specified in Python. If the same parameter is specified in both '
'`--config_file` and `--params_override`, `config_file` will be used '
'first, followed by params_override.')
flags.DEFINE_string(
'params_override', '',
'The JSON/YAML file or string which specifies the parameter to be overriden'
' on top of `config_file` template.')
flags.DEFINE_string(
'saved_model_dir', None, 'The directory to the saved model.', required=True)
flags.DEFINE_string(
'tflite_path', None, 'The path to the output tflite model.', required=True)
flags.DEFINE_string(
'quant_type',
default=None,
help='Post training quantization type. Support `int8`, `int8_full`, '
'`fp16`, and `default`. See '
'https://www.tensorflow.org/lite/performance/post_training_quantization '
'for more details.')
flags.DEFINE_integer('calibration_steps', 500,
'The number of calibration steps for integer model.')
def main(_) -> None:
params = exp_factory.get_exp_config(FLAGS.experiment)
if FLAGS.config_file is not None:
for config_file in FLAGS.config_file:
params = hyperparams.override_params_dict(
params, config_file, is_strict=True)
if FLAGS.params_override:
params = hyperparams.override_params_dict(
params, FLAGS.params_override, is_strict=True)
params.validate()
params.lock()
logging.info('Converting SavedModel from %s to TFLite model...',
FLAGS.saved_model_dir)
tflite_model = export_tflite_lib.convert_tflite_model(
saved_model_dir=FLAGS.saved_model_dir,
quant_type=FLAGS.quant_type,
params=params,
calibration_steps=FLAGS.calibration_steps)
with tf.io.gfile.GFile(FLAGS.tflite_path, 'wb') as fw:
fw.write(tflite_model)
logging.info('TFLite model converted and saved to %s.', FLAGS.tflite_path)
if __name__ == '__main__':
app.run(main)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Library to facilitate TFLite model conversion."""
import functools
from typing import Iterator, List, Optional
import tensorflow as tf
from official.core import config_definitions as cfg
from official.vision.beta import configs
from official.vision.beta.tasks import image_classification as img_cls_task
def create_representative_dataset(
params: cfg.ExperimentConfig) -> tf.data.Dataset:
"""Creates a tf.data.Dataset to load images for representative dataset.
Args:
params: An ExperimentConfig.
Returns:
A tf.data.Dataset instance.
Raises:
ValueError: If task is not supported.
"""
if isinstance(params.task,
configs.image_classification.ImageClassificationTask):
task = img_cls_task.ImageClassificationTask(params.task)
else:
raise ValueError('Task {} not supported.'.format(type(params.task)))
# Ensure batch size is 1 for TFLite model.
params.task.train_data.global_batch_size = 1
params.task.train_data.dtype = 'float32'
return task.build_inputs(params=params.task.train_data)
def representative_dataset(
params: cfg.ExperimentConfig,
calibration_steps: int = 2000) -> Iterator[List[tf.Tensor]]:
""""Creates representative dataset for input calibration.
Args:
params: An ExperimentConfig.
calibration_steps: The steps to do calibration.
Yields:
An input image tensor.
"""
dataset = create_representative_dataset(params=params)
for image, _ in dataset.take(calibration_steps):
# Skip images that do not have 3 channels.
if image.shape[-1] != 3:
continue
yield [image]
def convert_tflite_model(saved_model_dir: str,
quant_type: Optional[str] = None,
params: Optional[cfg.ExperimentConfig] = None,
calibration_steps: Optional[int] = 2000) -> bytes:
"""Converts and returns a TFLite model.
Args:
saved_model_dir: The directory to the SavedModel.
quant_type: The post training quantization (PTQ) method. It can be one of
`default` (dynamic range), `fp16` (float16), `int8` (integer wih float
fallback), `int8_full` (integer only) and None (no quantization).
params: An optional ExperimentConfig to load and preprocess input images to
do calibration for integer quantization.
calibration_steps: The steps to do calibration.
Returns:
A converted TFLite model with optional PTQ.
Raises:
ValueError: If `representative_dataset_path` is not present if integer
quantization is requested.
"""
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
if quant_type:
if quant_type.startswith('int8'):
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = functools.partial(
representative_dataset,
params=params,
calibration_steps=calibration_steps)
if quant_type == 'int8_full':
converter.target_spec.supported_ops = [
tf.lite.OpsSet.TFLITE_BUILTINS_INT8
]
converter.inference_input_type = tf.uint8 # or tf.int8
converter.inference_output_type = tf.uint8 # or tf.int8
elif quant_type == 'fp16':
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_types = [tf.float16]
elif quant_type == 'default':
converter.optimizations = [tf.lite.Optimize.DEFAULT]
return converter.convert()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for export_tflite_lib."""
import os
from absl.testing import parameterized
import tensorflow as tf
from tensorflow.python.distribute import combinations
from official.common import registry_imports # pylint: disable=unused-import
from official.core import exp_factory
from official.vision.beta.dataloaders import tfexample_utils
from official.vision.beta.serving import export_tflite_lib
from official.vision.beta.serving import image_classification as image_classification_serving
class ExportTfliteLibTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
super().setUp()
self._test_tfrecord_file = os.path.join(self.get_temp_dir(),
'test.tfrecord')
self._create_test_tfrecord(num_samples=50)
def _create_test_tfrecord(self, num_samples):
tfexample_utils.dump_to_tfrecord(self._test_tfrecord_file, [
tf.train.Example.FromString(
tfexample_utils.create_classification_example(
image_height=256, image_width=256)) for _ in range(num_samples)
])
def _export_from_module(self, module, input_type, saved_model_dir):
signatures = module.get_inference_signatures(
{input_type: 'serving_default'})
tf.saved_model.save(module, saved_model_dir, signatures=signatures)
@combinations.generate(
combinations.combine(
experiment=['mobilenet_imagenet'],
quant_type=[None, 'default', 'fp16', 'int8'],
input_image_size=[[224, 224]]))
def test_export_tflite(self, experiment, quant_type, input_image_size):
params = exp_factory.get_exp_config(experiment)
params.task.validation_data.input_path = self._test_tfrecord_file
temp_dir = self.get_temp_dir()
module = image_classification_serving.ClassificationModule(
params=params, batch_size=1, input_image_size=input_image_size)
self._export_from_module(
module=module,
input_type='tflite',
saved_model_dir=os.path.join(temp_dir, 'saved_model'))
tflite_model = export_tflite_lib.convert_tflite_model(
saved_model_dir=os.path.join(temp_dir, 'saved_model'),
quant_type=quant_type,
params=params,
calibration_steps=20)
self.assertIsInstance(tflite_model, bytes)
if __name__ == '__main__':
tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment