"tests/compute/test_index.py" did not exist on "eafcb7e7f55d385099a9289b275e8371897edb9f"
Commit 083ee92f authored by Fan Yang's avatar Fan Yang Committed by A. Unique TensorFlower
Browse files

Improve support in TFLite model conversion.

PiperOrigin-RevId: 472790049
parent a3f34adb
......@@ -44,12 +44,12 @@ from official.vision.serving import export_tflite_lib
FLAGS = flags.FLAGS
flags.DEFINE_string(
_EXPERIMENT = flags.DEFINE_string(
'experiment',
None,
'experiment type, e.g. retinanet_resnetfpn_coco',
required=True)
flags.DEFINE_multi_string(
_CONFIG_FILE = flags.DEFINE_multi_string(
'config_file',
default='',
help='YAML/JSON files which specifies overrides. The override order '
......@@ -58,15 +58,15 @@ flags.DEFINE_multi_string(
'specified in Python. If the same parameter is specified in both '
'`--config_file` and `--params_override`, `config_file` will be used '
'first, followed by params_override.')
flags.DEFINE_string(
_PARAMS_OVERRIDE = flags.DEFINE_string(
'params_override', '',
'The JSON/YAML file or string which specifies the parameter to be overriden'
' on top of `config_file` template.')
flags.DEFINE_string(
_SAVED_MODEL_DIR = flags.DEFINE_string(
'saved_model_dir', None, 'The directory to the saved model.', required=True)
flags.DEFINE_string(
_TFLITE_PATH = flags.DEFINE_string(
'tflite_path', None, 'The path to the output tflite model.', required=True)
flags.DEFINE_string(
_QUANT_TYPE = flags.DEFINE_string(
'quant_type',
default=None,
help='Post training quantization type. Support `int8_fallback`, '
......@@ -74,35 +74,47 @@ flags.DEFINE_string(
'`int8_full_int8_io` and `default`. See '
'https://www.tensorflow.org/lite/performance/post_training_quantization '
'for more details.')
flags.DEFINE_integer('calibration_steps', 500,
_CALIBRATION_STEPS = flags.DEFINE_integer(
'calibration_steps', 500,
'The number of calibration steps for integer model.')
_DENYLISTED_OPS = flags.DEFINE_string(
'denylisted_ops', '', 'The comma-separated string of ops '
'that are excluded from integer quantization. The name of '
'ops should be all capital letters, such as CAST or GREATER.'
'This is useful to exclude certains ops that affects quality or latency. '
'Valid ops that should not be included are quantization friendly ops, such '
'as CONV_2D, DEPTHWISE_CONV_2D, FULLY_CONNECTED, etc.')
def main(_) -> None:
  """Converts a SavedModel to a TFLite model per the experiment config.

  Builds the experiment config from --experiment, applies YAML/JSON
  --config_file overrides and then --params_override on top, converts the
  SavedModel at --saved_model_dir with optional post-training quantization
  (--quant_type), and writes the TFLite flatbuffer to --tflite_path.
  """
  params = exp_factory.get_exp_config(_EXPERIMENT.value)
  if _CONFIG_FILE.value is not None:
    for config_file in _CONFIG_FILE.value:
      # Each config file is applied in order; later files override earlier.
      params = hyperparams.override_params_dict(
          params, config_file, is_strict=True)
  if _PARAMS_OVERRIDE.value:
    # --params_override is applied last, on top of all config files.
    params = hyperparams.override_params_dict(
        params, _PARAMS_OVERRIDE.value, is_strict=True)

  params.validate()
  params.lock()

  logging.info('Converting SavedModel from %s to TFLite model...',
               _SAVED_MODEL_DIR.value)

  # BUG FIX: `denylisted_ops` was previously assigned only inside the `if`
  # branch but passed unconditionally below, raising a NameError whenever
  # --denylisted_ops was empty. Default to None so conversion proceeds with
  # no op denylist in that case.
  denylisted_ops = None
  if _DENYLISTED_OPS.value:
    # Flag value is a comma-separated string of op names, e.g. "CAST,GREATER".
    denylisted_ops = list(_DENYLISTED_OPS.value.split(','))

  tflite_model = export_tflite_lib.convert_tflite_model(
      saved_model_dir=_SAVED_MODEL_DIR.value,
      quant_type=_QUANT_TYPE.value,
      params=params,
      calibration_steps=_CALIBRATION_STEPS.value,
      denylisted_ops=denylisted_ops)

  with tf.io.gfile.GFile(_TFLITE_PATH.value, 'wb') as fw:
    fw.write(tflite_model)

  logging.info('TFLite model converted and saved to %s.', _TFLITE_PATH.value)
if __name__ == '__main__':
......
......@@ -19,17 +19,21 @@ from typing import Iterator, List, Optional
from absl import logging
import tensorflow as tf
from official.core import base_task
from official.core import config_definitions as cfg
from official.vision import configs
from official.vision import tasks
def create_representative_dataset(
params: cfg.ExperimentConfig) -> tf.data.Dataset:
params: cfg.ExperimentConfig,
task: Optional[base_task.Task] = None) -> tf.data.Dataset:
"""Creates a tf.data.Dataset to load images for representative dataset.
Args:
params: An ExperimentConfig.
task: An optional task instance. If it is None, task will be built according
to the task type in params.
Returns:
A tf.data.Dataset instance.
......@@ -37,6 +41,7 @@ def create_representative_dataset(
Raises:
ValueError: If task is not supported.
"""
if task is None:
if isinstance(params.task,
configs.image_classification.ImageClassificationTask):
......@@ -59,17 +64,20 @@ def create_representative_dataset(
def representative_dataset(
params: cfg.ExperimentConfig,
task: Optional[base_task.Task] = None,
calibration_steps: int = 2000) -> Iterator[List[tf.Tensor]]:
""""Creates representative dataset for input calibration.
Args:
params: An ExperimentConfig.
task: An optional task instance. If it is None, task will be built according
to the task type in params.
calibration_steps: The steps to do calibration.
Yields:
An input image tensor.
"""
dataset = create_representative_dataset(params=params)
dataset = create_representative_dataset(params=params, task=task)
for image, _ in dataset.take(calibration_steps):
# Skip images that do not have 3 channels.
if image.shape[-1] != 3:
......@@ -80,7 +88,9 @@ def representative_dataset(
def convert_tflite_model(saved_model_dir: str,
quant_type: Optional[str] = None,
params: Optional[cfg.ExperimentConfig] = None,
calibration_steps: Optional[int] = 2000) -> bytes:
task: Optional[base_task.Task] = None,
calibration_steps: Optional[int] = 2000,
denylisted_ops: Optional[list[str]] = None) -> bytes:
"""Converts and returns a TFLite model.
Args:
......@@ -90,7 +100,11 @@ def convert_tflite_model(saved_model_dir: str,
fallback), `int8_full` (integer only) and None (no quantization).
params: An optional ExperimentConfig to load and preprocess input images to
do calibration for integer quantization.
task: An optional task instance. If it is None, task will be built according
to the task type in params.
calibration_steps: The steps to do calibration.
denylisted_ops: A list of strings containing ops that are excluded from
integer quantization.
Returns:
A converted TFLite model with optional PTQ.
......@@ -106,6 +120,7 @@ def convert_tflite_model(saved_model_dir: str,
converter.representative_dataset = functools.partial(
representative_dataset,
params=params,
task=task,
calibration_steps=calibration_steps)
if quant_type.startswith('int8_full'):
converter.target_spec.supported_ops = [
......@@ -117,6 +132,20 @@ def convert_tflite_model(saved_model_dir: str,
if quant_type == 'int8_full_int8_io':
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8
if denylisted_ops:
debug_options = tf.lite.experimental.QuantizationDebugOptions(
denylisted_ops=denylisted_ops)
debugger = tf.lite.experimental.QuantizationDebugger(
converter=converter,
debug_dataset=functools.partial(
representative_dataset,
params=params,
calibration_steps=calibration_steps),
debug_options=debug_options)
debugger.run()
return debugger.get_nondebug_quantized_model()
elif quant_type == 'fp16':
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_types = [tf.float16]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment