Commit 83911092 authored by Fan Yang, committed by A. Unique TensorFlower

Add an option to convert a full INT8 TFLite model with FP32 input and output.

PiperOrigin-RevId: 468808165
parent f46d7b9d
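
For context, the conversion this option enables can be sketched with the standard tf.lite.TFLiteConverter API. This is a minimal, hedged illustration (the saved-model path and the random calibration data are placeholders, not taken from this commit):

import numpy as np
import tensorflow as tf


def representative_dataset():
  # Placeholder calibration data; a real exporter iterates over task inputs.
  for _ in range(100):
    yield [np.random.rand(1, 224, 224, 3).astype(np.float32)]


converter = tf.lite.TFLiteConverter.from_saved_model('/path/to/saved_model')
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset
# Restrict every op to INT8 kernels, exactly as `int8_full` does...
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
# ...but leave inference_input_type / inference_output_type at their default
# of tf.float32, so the converted model keeps FP32 input and output and the
# converter inserts quantize/dequantize ops at the graph boundaries.
# `int8_full` would instead set both to tf.uint8 (or tf.int8).
tflite_model = converter.convert()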
...
@@ -69,8 +69,9 @@ flags.DEFINE_string(
 flags.DEFINE_string(
     'quant_type',
     default=None,
-    help='Post training quantization type. Support `int8`, `int8_full`, '
-    '`fp16`, and `default`. See '
+    help='Post training quantization type. Support `int8_fp32_fallback`, '
+    '`int8_fp32_input_output`, `int8_full`, `fp16`, `qat`, '
+    '`qat_fp32_input_output`, and `default`. See '
     'https://www.tensorflow.org/lite/performance/post_training_quantization '
     'for more details.')
 flags.DEFINE_integer('calibration_steps', 500,
...
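
For orientation, here is a hedged sketch of how an absl-based export script might consume these flags; everything except `quant_type` and `calibration_steps` (and the convert_tflite_model signature visible in the next hunk) is an assumption for illustration, not taken from this commit:

from absl import app
from absl import flags

FLAGS = flags.FLAGS  # `quant_type` and `calibration_steps` are defined above.


def main(_) -> None:
  # The flag value arrives as a plain string (or None) and is forwarded
  # verbatim to the converter dispatch shown in the next hunk.
  tflite_model = convert_tflite_model(  # assumed import; signature per diff
      saved_model_dir=FLAGS.saved_model_dir,  # hypothetical companion flag
      quant_type=FLAGS.quant_type,
      calibration_steps=FLAGS.calibration_steps)
  with open(FLAGS.export_path, 'wb') as f:  # hypothetical companion flag
    f.write(tflite_model)


if __name__ == '__main__':
  app.run(main)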
...
@@ -107,16 +107,17 @@ def convert_tflite_model(saved_model_dir: str,
           representative_dataset,
           params=params,
           calibration_steps=calibration_steps)
-      if quant_type == 'int8_full':
+      if quant_type in ('int8_full', 'int8_fp32_input_output'):
         converter.target_spec.supported_ops = [
             tf.lite.OpsSet.TFLITE_BUILTINS_INT8
         ]
+      if quant_type == 'int8_full':
         converter.inference_input_type = tf.uint8  # or tf.int8
         converter.inference_output_type = tf.uint8  # or tf.int8
     elif quant_type == 'fp16':
       converter.optimizations = [tf.lite.Optimize.DEFAULT]
       converter.target_spec.supported_types = [tf.float16]
-    elif quant_type == 'default':
+    elif quant_type in ('default', 'qat_fp32_input_output'):
       converter.optimizations = [tf.lite.Optimize.DEFAULT]
     elif quant_type == 'qat':
       converter.optimizations = [tf.lite.Optimize.DEFAULT]
...
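
One way to see the difference between the two INT8 branches above is to inspect the converted model's boundary tensors. A hedged sketch, assuming `tflite_model` holds the flatbuffer bytes returned by convert_tflite_model:

import numpy as np
import tensorflow as tf

# Assuming `tflite_model` was produced with
# quant_type='int8_fp32_input_output'.
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
# The boundary tensors stay FP32 even though the graph itself is INT8-only.
assert interpreter.get_input_details()[0]['dtype'] == np.float32
assert interpreter.get_output_details()[0]['dtype'] == np.float32
# With quant_type='int8_full', the same checks would see np.uint8
# (or np.int8) instead.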
...
@@ -76,7 +76,14 @@ class ExportTfliteLibTest(tf.test.TestCase, parameterized.TestCase):
   @combinations.generate(
       combinations.combine(
           experiment=['mobilenet_imagenet'],
-          quant_type=[None, 'default', 'fp16', 'int8', 'int8_full']))
+          quant_type=[
+              None,
+              'default',
+              'fp16',
+              'int8_fp32_fallback',
+              'int8_full',
+              'int8_fp32_input_output',
+          ]))
   def test_export_tflite_image_classification(self, experiment, quant_type):
     params = exp_factory.get_exp_config(experiment)
...
@@ -105,7 +112,14 @@ class ExportTfliteLibTest(tf.test.TestCase, parameterized.TestCase):
   @combinations.generate(
       combinations.combine(
           experiment=['retinanet_mobile_coco'],
-          quant_type=[None, 'default', 'fp16', 'int8', 'int8_full']))
+          quant_type=[
+              None,
+              'default',
+              'fp16',
+              'int8_fp32_fallback',
+              'int8_full',
+              'int8_fp32_input_output',
+          ]))
   def test_export_tflite_detection(self, experiment, quant_type):
     params = exp_factory.get_exp_config(experiment)
...
@@ -138,7 +152,14 @@ class ExportTfliteLibTest(tf.test.TestCase, parameterized.TestCase):
   @combinations.generate(
       combinations.combine(
           experiment=['mnv2_deeplabv3_pascal'],
-          quant_type=[None, 'default', 'fp16', 'int8', 'int8_full']))
+          quant_type=[
+              None,
+              'default',
+              'fp16',
+              'int8_fp32_fallback',
+              'int8_full',
+              'int8_fp32_input_output',
+          ]))
   def test_export_tflite_semantic_segmentation(self, experiment, quant_type):
     params = exp_factory.get_exp_config(experiment)
...
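
As an aside, combinations.combine expands its keyword arguments into a cartesian product, so each test method above runs once per listed quant_type. A rough stand-alone illustration of that expansion (hypothetical, not part of the test harness):

import itertools

experiments = ['mobilenet_imagenet']
quant_types = [None, 'default', 'fp16', 'int8_fp32_fallback',
               'int8_full', 'int8_fp32_input_output']

# Each (experiment, quant_type) pair becomes one generated test case.
for experiment, quant_type in itertools.product(experiments, quant_types):
  print(f'test_export_tflite_image_classification('
        f'experiment={experiment!r}, quant_type={quant_type!r})')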