Commit 201d523a authored by Fan Yang's avatar Fan Yang Committed by A. Unique TensorFlower
Browse files

Internal change.

PiperOrigin-RevId: 398514419
parent 36ab0686
......@@ -21,7 +21,7 @@ import tensorflow as tf
from official.core import config_definitions as cfg
from official.vision.beta import configs
from official.vision.beta.tasks import image_classification as img_cls_task
from official.vision.beta import tasks
def create_representative_dataset(
......@@ -39,7 +39,13 @@ def create_representative_dataset(
"""
if isinstance(params.task,
configs.image_classification.ImageClassificationTask):
task = img_cls_task.ImageClassificationTask(params.task)
task = tasks.image_classification.ImageClassificationTask(params.task)
elif isinstance(params.task, configs.retinanet.RetinaNetTask):
task = tasks.retinanet.RetinaNetTask(params.task)
elif isinstance(params.task,
configs.semantic_segmentation.SemanticSegmentationTask):
task = tasks.semantic_segmentation.SemanticSegmentationTask(params.task)
else:
raise ValueError('Task {} not supported.'.format(type(params.task)))
# Ensure batch size is 1 for TFLite model.
......
......@@ -22,8 +22,10 @@ from tensorflow.python.distribute import combinations
from official.common import registry_imports # pylint: disable=unused-import
from official.core import exp_factory
from official.vision.beta.dataloaders import tfexample_utils
from official.vision.beta.serving import detection as detection_serving
from official.vision.beta.serving import export_tflite_lib
from official.vision.beta.serving import image_classification as image_classification_serving
from official.vision.beta.serving import semantic_segmentation as semantic_segmentation_serving
class ExportTfliteLibTest(tf.test.TestCase, parameterized.TestCase):
......@@ -51,7 +53,8 @@ class ExportTfliteLibTest(tf.test.TestCase, parameterized.TestCase):
experiment=['mobilenet_imagenet'],
quant_type=[None, 'default', 'fp16', 'int8'],
input_image_size=[[224, 224]]))
def test_export_tflite(self, experiment, quant_type, input_image_size):
def test_export_tflite_image_classification(self, experiment, quant_type,
input_image_size):
params = exp_factory.get_exp_config(experiment)
params.task.validation_data.input_path = self._test_tfrecord_file
params.task.train_data.input_path = self._test_tfrecord_file
......@@ -71,6 +74,53 @@ class ExportTfliteLibTest(tf.test.TestCase, parameterized.TestCase):
self.assertIsInstance(tflite_model, bytes)
@combinations.generate(
combinations.combine(
experiment=['retinanet_mobile_coco'],
quant_type=[None, 'default', 'fp16'],
input_image_size=[[256, 256]]))
def test_export_tflite_detection(self, experiment, quant_type,
input_image_size):
params = exp_factory.get_exp_config(experiment)
temp_dir = self.get_temp_dir()
module = detection_serving.DetectionModule(
params=params, batch_size=1, input_image_size=input_image_size)
self._export_from_module(
module=module,
input_type='tflite',
saved_model_dir=os.path.join(temp_dir, 'saved_model'))
tflite_model = export_tflite_lib.convert_tflite_model(
saved_model_dir=os.path.join(temp_dir, 'saved_model'),
quant_type=quant_type,
params=params,
calibration_steps=5)
self.assertIsInstance(tflite_model, bytes)
@combinations.generate(
combinations.combine(
experiment=['seg_deeplabv3_pascal'],
quant_type=[None, 'default', 'fp16'],
input_image_size=[[512, 512]]))
def test_export_tflite_semantic_segmentation(self, experiment, quant_type,
input_image_size):
params = exp_factory.get_exp_config(experiment)
temp_dir = self.get_temp_dir()
module = semantic_segmentation_serving.SegmentationModule(
params=params, batch_size=1, input_image_size=input_image_size)
self._export_from_module(
module=module,
input_type='tflite',
saved_model_dir=os.path.join(temp_dir, 'saved_model'))
tflite_model = export_tflite_lib.convert_tflite_model(
saved_model_dir=os.path.join(temp_dir, 'saved_model'),
quant_type=quant_type,
params=params,
calibration_steps=5)
self.assertIsInstance(tflite_model, bytes)
if __name__ == '__main__':
tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment