Commit 201d523a authored by Fan Yang's avatar Fan Yang Committed by A. Unique TensorFlower
Browse files

Internal change.

PiperOrigin-RevId: 398514419
parent 36ab0686
...@@ -21,7 +21,7 @@ import tensorflow as tf ...@@ -21,7 +21,7 @@ import tensorflow as tf
from official.core import config_definitions as cfg from official.core import config_definitions as cfg
from official.vision.beta import configs from official.vision.beta import configs
from official.vision.beta.tasks import image_classification as img_cls_task from official.vision.beta import tasks
def create_representative_dataset( def create_representative_dataset(
...@@ -39,7 +39,13 @@ def create_representative_dataset( ...@@ -39,7 +39,13 @@ def create_representative_dataset(
""" """
if isinstance(params.task, if isinstance(params.task,
configs.image_classification.ImageClassificationTask): configs.image_classification.ImageClassificationTask):
task = img_cls_task.ImageClassificationTask(params.task)
task = tasks.image_classification.ImageClassificationTask(params.task)
elif isinstance(params.task, configs.retinanet.RetinaNetTask):
task = tasks.retinanet.RetinaNetTask(params.task)
elif isinstance(params.task,
configs.semantic_segmentation.SemanticSegmentationTask):
task = tasks.semantic_segmentation.SemanticSegmentationTask(params.task)
else: else:
raise ValueError('Task {} not supported.'.format(type(params.task))) raise ValueError('Task {} not supported.'.format(type(params.task)))
# Ensure batch size is 1 for TFLite model. # Ensure batch size is 1 for TFLite model.
......
...@@ -22,8 +22,10 @@ from tensorflow.python.distribute import combinations ...@@ -22,8 +22,10 @@ from tensorflow.python.distribute import combinations
from official.common import registry_imports # pylint: disable=unused-import from official.common import registry_imports # pylint: disable=unused-import
from official.core import exp_factory from official.core import exp_factory
from official.vision.beta.dataloaders import tfexample_utils from official.vision.beta.dataloaders import tfexample_utils
from official.vision.beta.serving import detection as detection_serving
from official.vision.beta.serving import export_tflite_lib from official.vision.beta.serving import export_tflite_lib
from official.vision.beta.serving import image_classification as image_classification_serving from official.vision.beta.serving import image_classification as image_classification_serving
from official.vision.beta.serving import semantic_segmentation as semantic_segmentation_serving
class ExportTfliteLibTest(tf.test.TestCase, parameterized.TestCase): class ExportTfliteLibTest(tf.test.TestCase, parameterized.TestCase):
...@@ -51,7 +53,8 @@ class ExportTfliteLibTest(tf.test.TestCase, parameterized.TestCase): ...@@ -51,7 +53,8 @@ class ExportTfliteLibTest(tf.test.TestCase, parameterized.TestCase):
experiment=['mobilenet_imagenet'], experiment=['mobilenet_imagenet'],
quant_type=[None, 'default', 'fp16', 'int8'], quant_type=[None, 'default', 'fp16', 'int8'],
input_image_size=[[224, 224]])) input_image_size=[[224, 224]]))
def test_export_tflite(self, experiment, quant_type, input_image_size): def test_export_tflite_image_classification(self, experiment, quant_type,
input_image_size):
params = exp_factory.get_exp_config(experiment) params = exp_factory.get_exp_config(experiment)
params.task.validation_data.input_path = self._test_tfrecord_file params.task.validation_data.input_path = self._test_tfrecord_file
params.task.train_data.input_path = self._test_tfrecord_file params.task.train_data.input_path = self._test_tfrecord_file
...@@ -71,6 +74,53 @@ class ExportTfliteLibTest(tf.test.TestCase, parameterized.TestCase): ...@@ -71,6 +74,53 @@ class ExportTfliteLibTest(tf.test.TestCase, parameterized.TestCase):
self.assertIsInstance(tflite_model, bytes) self.assertIsInstance(tflite_model, bytes)
@combinations.generate(
combinations.combine(
experiment=['retinanet_mobile_coco'],
quant_type=[None, 'default', 'fp16'],
input_image_size=[[256, 256]]))
def test_export_tflite_detection(self, experiment, quant_type,
input_image_size):
params = exp_factory.get_exp_config(experiment)
temp_dir = self.get_temp_dir()
module = detection_serving.DetectionModule(
params=params, batch_size=1, input_image_size=input_image_size)
self._export_from_module(
module=module,
input_type='tflite',
saved_model_dir=os.path.join(temp_dir, 'saved_model'))
tflite_model = export_tflite_lib.convert_tflite_model(
saved_model_dir=os.path.join(temp_dir, 'saved_model'),
quant_type=quant_type,
params=params,
calibration_steps=5)
self.assertIsInstance(tflite_model, bytes)
@combinations.generate(
combinations.combine(
experiment=['seg_deeplabv3_pascal'],
quant_type=[None, 'default', 'fp16'],
input_image_size=[[512, 512]]))
def test_export_tflite_semantic_segmentation(self, experiment, quant_type,
input_image_size):
params = exp_factory.get_exp_config(experiment)
temp_dir = self.get_temp_dir()
module = semantic_segmentation_serving.SegmentationModule(
params=params, batch_size=1, input_image_size=input_image_size)
self._export_from_module(
module=module,
input_type='tflite',
saved_model_dir=os.path.join(temp_dir, 'saved_model'))
tflite_model = export_tflite_lib.convert_tflite_model(
saved_model_dir=os.path.join(temp_dir, 'saved_model'),
quant_type=quant_type,
params=params,
calibration_steps=5)
self.assertIsInstance(tflite_model, bytes)
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment