Commit 083ee92f authored by Fan Yang, committed by A. Unique TensorFlower

Improve support in TFLite model conversion.

PiperOrigin-RevId: 472790049
parent a3f34adb
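
As a usage note, the new conversion path can also be driven directly from Python instead of through the flags below. A minimal sketch, reusing the registered `retinanet_resnetfpn_coco` experiment named in the flag help; the paths, the op list, and the `registry_imports` bootstrap are illustrative assumptions, not part of this commit:

# Sketch: calling the updated conversion API directly (illustrative).
from official.common import registry_imports  # assumed bootstrap that registers vision tasks
from official.core import exp_factory
from official.vision.serving import export_tflite_lib
import tensorflow as tf

params = exp_factory.get_exp_config('retinanet_resnetfpn_coco')
params.validate()
params.lock()

tflite_model = export_tflite_lib.convert_tflite_model(
    saved_model_dir='/tmp/retinanet/saved_model',  # illustrative path
    quant_type='int8_full',
    params=params,
    calibration_steps=100,
    denylisted_ops=['CAST', 'GREATER'])  # ops kept in float

with tf.io.gfile.GFile('/tmp/retinanet.tflite', 'wb') as fw:
  fw.write(tflite_model)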
official/vision/serving/export_tflite.py
@@ -44,12 +44,12 @@ from official.vision.serving import export_tflite_lib
 FLAGS = flags.FLAGS
 
-flags.DEFINE_string(
+_EXPERIMENT = flags.DEFINE_string(
     'experiment',
     None,
     'experiment type, e.g. retinanet_resnetfpn_coco',
    required=True)
-flags.DEFINE_multi_string(
+_CONFIG_FILE = flags.DEFINE_multi_string(
     'config_file',
     default='',
     help='YAML/JSON files which specify overrides. The override order '
@@ -58,15 +58,15 @@ flags.DEFINE_multi_string(
     'specified in Python. If the same parameter is specified in both '
     '`--config_file` and `--params_override`, `config_file` will be used '
     'first, followed by params_override.')
-flags.DEFINE_string(
+_PARAMS_OVERRIDE = flags.DEFINE_string(
     'params_override', '',
     'The JSON/YAML file or string which specifies the parameter to be '
     'overridden on top of `config_file` template.')
-flags.DEFINE_string(
+_SAVED_MODEL_DIR = flags.DEFINE_string(
     'saved_model_dir', None, 'The directory to the saved model.', required=True)
-flags.DEFINE_string(
+_TFLITE_PATH = flags.DEFINE_string(
     'tflite_path', None, 'The path to the output tflite model.', required=True)
-flags.DEFINE_string(
+_QUANT_TYPE = flags.DEFINE_string(
     'quant_type',
     default=None,
     help='Post training quantization type. Support `int8_fallback`, '
@@ -74,35 +74,47 @@ flags.DEFINE_string(
     '`int8_full_int8_io` and `default`. See '
     'https://www.tensorflow.org/lite/performance/post_training_quantization '
     'for more details.')
-flags.DEFINE_integer('calibration_steps', 500,
-                     'The number of calibration steps for integer model.')
+_CALIBRATION_STEPS = flags.DEFINE_integer(
+    'calibration_steps', 500,
+    'The number of calibration steps for integer model.')
+_DENYLISTED_OPS = flags.DEFINE_string(
+    'denylisted_ops', '',
+    'A comma-separated list of ops to exclude from integer quantization. '
+    'Op names should be in all capital letters, such as CAST or GREATER. '
+    'This is useful for excluding certain ops that affect quality or '
+    'latency. Quantization-friendly ops such as CONV_2D, DEPTHWISE_CONV_2D '
+    'and FULLY_CONNECTED should not be included.')
 
 
 def main(_) -> None:
-  params = exp_factory.get_exp_config(FLAGS.experiment)
-  if FLAGS.config_file is not None:
-    for config_file in FLAGS.config_file:
+  params = exp_factory.get_exp_config(_EXPERIMENT.value)
+  if _CONFIG_FILE.value is not None:
+    for config_file in _CONFIG_FILE.value:
       params = hyperparams.override_params_dict(
           params, config_file, is_strict=True)
-  if FLAGS.params_override:
+  if _PARAMS_OVERRIDE.value:
     params = hyperparams.override_params_dict(
-        params, FLAGS.params_override, is_strict=True)
+        params, _PARAMS_OVERRIDE.value, is_strict=True)
 
   params.validate()
   params.lock()
 
   logging.info('Converting SavedModel from %s to TFLite model...',
-               FLAGS.saved_model_dir)
+               _SAVED_MODEL_DIR.value)
+
+  # Default to None so conversion is unchanged when the flag is unset.
+  denylisted_ops = None
+  if _DENYLISTED_OPS.value:
+    denylisted_ops = list(_DENYLISTED_OPS.value.split(','))
 
   tflite_model = export_tflite_lib.convert_tflite_model(
-      saved_model_dir=FLAGS.saved_model_dir,
-      quant_type=FLAGS.quant_type,
+      saved_model_dir=_SAVED_MODEL_DIR.value,
+      quant_type=_QUANT_TYPE.value,
       params=params,
-      calibration_steps=FLAGS.calibration_steps)
+      calibration_steps=_CALIBRATION_STEPS.value,
+      denylisted_ops=denylisted_ops)
 
-  with tf.io.gfile.GFile(FLAGS.tflite_path, 'wb') as fw:
+  with tf.io.gfile.GFile(_TFLITE_PATH.value, 'wb') as fw:
     fw.write(tflite_model)
-  logging.info('TFLite model converted and saved to %s.', FLAGS.tflite_path)
+  logging.info('TFLite model converted and saved to %s.', _TFLITE_PATH.value)
 
 
 if __name__ == '__main__':
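
Beyond the new flag, this file also moves from reading the global FLAGS object to the module-level holders returned by flags.DEFINE_*, read via .value. A minimal standalone sketch of that absl pattern:

from absl import app
from absl import flags

# DEFINE_* returns a FlagHolder; reading .value after flag parsing avoids
# depending on the global FLAGS namespace and makes each flag's owner explicit.
_NAME = flags.DEFINE_string('name', 'world', 'Who to greet.')


def main(_) -> None:
  print(f'Hello, {_NAME.value}!')


if __name__ == '__main__':
  app.run(main)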
official/vision/serving/export_tflite_lib.py
@@ -19,17 +19,21 @@ from typing import Iterator, List, Optional
 from absl import logging
 import tensorflow as tf
 
+from official.core import base_task
 from official.core import config_definitions as cfg
 from official.vision import configs
 from official.vision import tasks
 
 
 def create_representative_dataset(
-    params: cfg.ExperimentConfig) -> tf.data.Dataset:
+    params: cfg.ExperimentConfig,
+    task: Optional[base_task.Task] = None) -> tf.data.Dataset:
   """Creates a tf.data.Dataset to load images for representative dataset.
 
   Args:
     params: An ExperimentConfig.
+    task: An optional task instance. If it is None, the task will be built
+      according to the task type in params.
 
   Returns:
     A tf.data.Dataset instance.
@@ -37,19 +41,20 @@ def create_representative_dataset(
 
   Raises:
     ValueError: If task is not supported.
   """
-  if isinstance(params.task,
-                configs.image_classification.ImageClassificationTask):
-    task = tasks.image_classification.ImageClassificationTask(params.task)
-  elif isinstance(params.task, configs.retinanet.RetinaNetTask):
-    task = tasks.retinanet.RetinaNetTask(params.task)
-  elif isinstance(params.task, configs.maskrcnn.MaskRCNNTask):
-    task = tasks.maskrcnn.MaskRCNNTask(params.task)
-  elif isinstance(params.task,
-                  configs.semantic_segmentation.SemanticSegmentationTask):
-    task = tasks.semantic_segmentation.SemanticSegmentationTask(params.task)
-  else:
-    raise ValueError('Task {} not supported.'.format(type(params.task)))
+  if task is None:
+    if isinstance(params.task,
+                  configs.image_classification.ImageClassificationTask):
+      task = tasks.image_classification.ImageClassificationTask(params.task)
+    elif isinstance(params.task, configs.retinanet.RetinaNetTask):
+      task = tasks.retinanet.RetinaNetTask(params.task)
+    elif isinstance(params.task, configs.maskrcnn.MaskRCNNTask):
+      task = tasks.maskrcnn.MaskRCNNTask(params.task)
+    elif isinstance(params.task,
+                    configs.semantic_segmentation.SemanticSegmentationTask):
+      task = tasks.semantic_segmentation.SemanticSegmentationTask(params.task)
+    else:
+      raise ValueError('Task {} not supported.'.format(type(params.task)))
 
   # Ensure batch size is 1 for TFLite model.
   params.task.train_data.global_batch_size = 1
   params.task.train_data.dtype = 'float32'
@@ -59,17 +64,20 @@
 
 def representative_dataset(
     params: cfg.ExperimentConfig,
+    task: Optional[base_task.Task] = None,
     calibration_steps: int = 2000) -> Iterator[List[tf.Tensor]]:
   """Creates representative dataset for input calibration.
 
   Args:
     params: An ExperimentConfig.
+    task: An optional task instance. If it is None, the task will be built
+      according to the task type in params.
     calibration_steps: The steps to do calibration.
 
   Yields:
     An input image tensor.
   """
-  dataset = create_representative_dataset(params=params)
+  dataset = create_representative_dataset(params=params, task=task)
   for image, _ in dataset.take(calibration_steps):
     # Skip images that do not have 3 channels.
     if image.shape[-1] != 3:
@@ -80,7 +88,9 @@ def representative_dataset(
 
 def convert_tflite_model(saved_model_dir: str,
                          quant_type: Optional[str] = None,
                          params: Optional[cfg.ExperimentConfig] = None,
-                         calibration_steps: Optional[int] = 2000) -> bytes:
+                         task: Optional[base_task.Task] = None,
+                         calibration_steps: Optional[int] = 2000,
+                         denylisted_ops: Optional[List[str]] = None) -> bytes:
   """Converts and returns a TFLite model.
 
   Args:
@@ -90,7 +100,11 @@ def convert_tflite_model(saved_model_dir: str,
       fallback), `int8_full` (integer only) and None (no quantization).
     params: An optional ExperimentConfig to load and preprocess input images
       to do calibration for integer quantization.
+    task: An optional task instance. If it is None, the task will be built
+      according to the task type in params.
     calibration_steps: The steps to do calibration.
+    denylisted_ops: A list of strings containing ops that are excluded from
+      integer quantization.
 
   Returns:
     A converted TFLite model with optional PTQ.
@@ -106,6 +120,7 @@ def convert_tflite_model(saved_model_dir: str,
     converter.representative_dataset = functools.partial(
         representative_dataset,
         params=params,
+        task=task,
         calibration_steps=calibration_steps)
     if quant_type.startswith('int8_full'):
       converter.target_spec.supported_ops = [
@@ -117,6 +132,20 @@ def convert_tflite_model(saved_model_dir: str,
     if quant_type == 'int8_full_int8_io':
       converter.inference_input_type = tf.int8
       converter.inference_output_type = tf.int8
+
+    if denylisted_ops:
+      # Use the quantization debugger to keep the denylisted ops in float
+      # while quantizing the rest of the model.
+      debug_options = tf.lite.experimental.QuantizationDebugOptions(
+          denylisted_ops=denylisted_ops)
+      debugger = tf.lite.experimental.QuantizationDebugger(
+          converter=converter,
+          debug_dataset=functools.partial(
+              representative_dataset,
+              params=params,
+              task=task,
+              calibration_steps=calibration_steps),
+          debug_options=debug_options)
+      debugger.run()
+      return debugger.get_nondebug_quantized_model()
   elif quant_type == 'fp16':
     converter.optimizations = [tf.lite.Optimize.DEFAULT]
     converter.target_spec.supported_types = [tf.float16]
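
The denylist support above is built on the TFLite quantization debugger, which can also be exercised on its own. A minimal sketch, assuming an illustrative SavedModel path and a random-input calibration generator in place of real images:

import tensorflow as tf


def calibration_gen():
  # A real run would yield preprocessed images; random data keeps this
  # sketch self-contained. Each element is a list with one tensor per input.
  for _ in range(10):
    yield [tf.random.uniform(shape=(1, 224, 224, 3), dtype=tf.float32)]


converter = tf.lite.TFLiteConverter.from_saved_model('/tmp/saved_model')
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = calibration_gen

# Ops named in the denylist stay in float; everything else is quantized.
debug_options = tf.lite.experimental.QuantizationDebugOptions(
    denylisted_ops=['CAST', 'GREATER'])
debugger = tf.lite.experimental.QuantizationDebugger(
    converter=converter,
    debug_dataset=calibration_gen,
    debug_options=debug_options)
debugger.run()
tflite_model = debugger.get_nondebug_quantized_model()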
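
The other change threaded through this file is the optional task argument, which lets a caller inject an already-built task rather than having the library reconstruct one from params. A sketch of that intended use; the experiment name is an assumption, not from this commit:

from official.core import exp_factory
from official.vision import tasks
from official.vision.serving import export_tflite_lib

params = exp_factory.get_exp_config('mobilenet_imagenet')  # assumed experiment
# Reuse a caller-constructed task (e.g. with a customized input pipeline)
# for representative-dataset calibration.
task = tasks.image_classification.ImageClassificationTask(params.task)
dataset = export_tflite_lib.create_representative_dataset(
    params=params, task=task)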