Commit 9b2f34bf authored by Fan Yang, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 469600398
parent 4862f9d3
@@ -69,9 +69,9 @@ flags.DEFINE_string(
 flags.DEFINE_string(
     'quant_type',
     default=None,
-    help='Post training quantization type. Support `int8_fp32_fallback`, '
-    '`int8_fp32_input_output`, `int8_full`, `fp16`, `qat`, '
-    '`qat_fp32_input_output`, and `default`. See '
+    help='Post training quantization type. Support `int8_fallback`, '
+    '`int8_full_fp32_io`, `int8_full`, `fp16`, `qat`, `qat_fp32_io`, '
+    '`int8_full_int8_io` and `default`. See '
     'https://www.tensorflow.org/lite/performance/post_training_quantization '
     'for more details.')
 flags.DEFINE_integer('calibration_steps', 500,
...
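For migration, the renames in this change map one-to-one onto the old flag values. The sketch below only summarizes the diff; QUANT_TYPE_RENAMES is a hypothetical helper name, not code from the commit:

    # Old quant_type values and their replacements introduced by this change.
    # `int8_full` keeps its name (uint8 input/output tensors); the new
    # `int8_full_int8_io` selects int8 input/output tensors instead.
    QUANT_TYPE_RENAMES = {
        'int8_fp32_fallback': 'int8_fallback',
        'int8_fp32_input_output': 'int8_full_fp32_io',
        'qat_fp32_input_output': 'qat_fp32_io',
    }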
@@ -107,17 +107,20 @@ def convert_tflite_model(saved_model_dir: str,
           representative_dataset,
           params=params,
           calibration_steps=calibration_steps)
-      if quant_type in ('int8_full', 'int8_fp32_input_output'):
+      if quant_type.startswith('int8_full'):
         converter.target_spec.supported_ops = [
             tf.lite.OpsSet.TFLITE_BUILTINS_INT8
         ]
       if quant_type == 'int8_full':
-        converter.inference_input_type = tf.uint8  # or tf.int8
-        converter.inference_output_type = tf.uint8  # or tf.int8
+        converter.inference_input_type = tf.uint8
+        converter.inference_output_type = tf.uint8
+      if quant_type == 'int8_full_int8_io':
+        converter.inference_input_type = tf.int8
+        converter.inference_output_type = tf.int8
     elif quant_type == 'fp16':
       converter.optimizations = [tf.lite.Optimize.DEFAULT]
       converter.target_spec.supported_types = [tf.float16]
-    elif quant_type in ('default', 'qat_fp32_input_output'):
+    elif quant_type in ('default', 'qat_fp32_io'):
       converter.optimizations = [tf.lite.Optimize.DEFAULT]
     elif quant_type == 'qat':
       converter.optimizations = [tf.lite.Optimize.DEFAULT]
...
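Outside this helper, the effect of the new `int8_full_int8_io` branch can be reproduced with the stock TFLite converter API. A minimal sketch, assuming a hypothetical SavedModel path and calibration generator (`rep_data_gen` and the input shape are assumptions, not part of this change):

    import tensorflow as tf

    def rep_data_gen():
      # Hypothetical calibration generator: yields batches matching the
      # model's input signature, here a single 224x224 RGB image.
      for _ in range(500):  # mirrors the default `calibration_steps` of 500
        yield [tf.random.uniform([1, 224, 224, 3], dtype=tf.float32)]

    converter = tf.lite.TFLiteConverter.from_saved_model('/tmp/saved_model')
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.representative_dataset = rep_data_gen
    # Full integer quantization, as in the `int8_full*` branches above.
    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
    # `int8_full_int8_io`: request int8 input/output tensors instead of uint8.
    converter.inference_input_type = tf.int8
    converter.inference_output_type = tf.int8
    tflite_model = converter.convert()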
@@ -80,9 +80,10 @@ class ExportTfliteLibTest(tf.test.TestCase, parameterized.TestCase):
           None,
           'default',
           'fp16',
-          'int8_fp32_fallback',
+          'int8_fallback',
           'int8_full',
-          'int8_fp32_input_output',
+          'int8_full_fp32_io',
+          'int8_full_int8_io',
       ]))
   def test_export_tflite_image_classification(self, experiment, quant_type):
@@ -116,9 +117,10 @@ class ExportTfliteLibTest(tf.test.TestCase, parameterized.TestCase):
           None,
           'default',
           'fp16',
-          'int8_fp32_fallback',
+          'int8_fallback',
           'int8_full',
-          'int8_fp32_input_output',
+          'int8_full_fp32_io',
+          'int8_full_int8_io',
       ]))
   def test_export_tflite_detection(self, experiment, quant_type):
@@ -156,9 +158,10 @@ class ExportTfliteLibTest(tf.test.TestCase, parameterized.TestCase):
           None,
           'default',
           'fp16',
-          'int8_fp32_fallback',
+          'int8_fallback',
           'int8_full',
-          'int8_fp32_input_output',
+          'int8_full_fp32_io',
+          'int8_full_int8_io',
       ]))
   def test_export_tflite_semantic_segmentation(self, experiment, quant_type):
...
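The updated tests only parameterize the conversion itself. A quick way to verify that `int8_full_int8_io` actually produced int8 I/O tensors (not part of this commit; continues from the converter sketch above, with `tflite_model` holding the converted flatbuffer) is to inspect it with the TFLite interpreter:

    import numpy as np
    import tensorflow as tf

    # Hypothetical follow-up check on the converted model bytes.
    interpreter = tf.lite.Interpreter(model_content=tflite_model)
    assert interpreter.get_input_details()[0]['dtype'] == np.int8
    assert interpreter.get_output_details()[0]['dtype'] == np.int8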