Commit f46d7b9d authored by Fan Yang's avatar Fan Yang Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 468798864
parent ef4f89e3
...@@ -30,6 +30,39 @@ from official.vision.serving import semantic_segmentation as semantic_segmentati ...@@ -30,6 +30,39 @@ from official.vision.serving import semantic_segmentation as semantic_segmentati
class ExportTfliteLibTest(tf.test.TestCase, parameterized.TestCase): class ExportTfliteLibTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
super().setUp()
# Create test data for image classification.
self.test_tfrecord_file_cls = os.path.join(self.get_temp_dir(),
'cls_test.tfrecord')
example = tf.train.Example.FromString(
tfexample_utils.create_classification_example(
image_height=224, image_width=224))
self._create_test_tfrecord(
tfrecord_file=self.test_tfrecord_file_cls,
example=example,
num_samples=10)
# Create test data for object detection.
self.test_tfrecord_file_det = os.path.join(self.get_temp_dir(),
'det_test.tfrecord')
example = tfexample_utils.create_detection_test_example(
image_height=128, image_width=128, image_channel=3, num_instances=10)
self._create_test_tfrecord(
tfrecord_file=self.test_tfrecord_file_det,
example=example,
num_samples=10)
# Create test data for semantic segmentation.
self.test_tfrecord_file_seg = os.path.join(self.get_temp_dir(),
'seg_test.tfrecord')
example = tfexample_utils.create_segmentation_test_example(
image_height=512, image_width=512, image_channel=3)
self._create_test_tfrecord(
tfrecord_file=self.test_tfrecord_file_seg,
example=example,
num_samples=10)
def _create_test_tfrecord(self, tfrecord_file, example, num_samples): def _create_test_tfrecord(self, tfrecord_file, example, num_samples):
examples = [example] * num_samples examples = [example] * num_samples
tfexample_utils.dump_to_tfrecord( tfexample_utils.dump_to_tfrecord(
...@@ -43,24 +76,18 @@ class ExportTfliteLibTest(tf.test.TestCase, parameterized.TestCase): ...@@ -43,24 +76,18 @@ class ExportTfliteLibTest(tf.test.TestCase, parameterized.TestCase):
@combinations.generate( @combinations.generate(
combinations.combine( combinations.combine(
experiment=['mobilenet_imagenet'], experiment=['mobilenet_imagenet'],
quant_type=[None, 'default', 'fp16', 'int8', 'int8_full'], quant_type=[None, 'default', 'fp16', 'int8', 'int8_full']))
input_image_size=[[224, 224]])) def test_export_tflite_image_classification(self, experiment, quant_type):
def test_export_tflite_image_classification(self, experiment, quant_type,
input_image_size):
test_tfrecord_file = os.path.join(self.get_temp_dir(), 'cls_test.tfrecord')
example = tf.train.Example.FromString(
tfexample_utils.create_classification_example(
image_height=input_image_size[0], image_width=input_image_size[1]))
self._create_test_tfrecord(
tfrecord_file=test_tfrecord_file, example=example, num_samples=10)
params = exp_factory.get_exp_config(experiment) params = exp_factory.get_exp_config(experiment)
params.task.validation_data.input_path = test_tfrecord_file params.task.validation_data.input_path = self.test_tfrecord_file_cls
params.task.train_data.input_path = test_tfrecord_file params.task.train_data.input_path = self.test_tfrecord_file_cls
params.task.train_data.shuffle_buffer_size = 10
temp_dir = self.get_temp_dir() temp_dir = self.get_temp_dir()
module = image_classification_serving.ClassificationModule( module = image_classification_serving.ClassificationModule(
params=params, params=params,
batch_size=1, batch_size=1,
input_image_size=input_image_size, input_image_size=[224, 224],
input_type='tflite') input_type='tflite')
self._export_from_module( self._export_from_module(
module=module, module=module,
...@@ -78,26 +105,22 @@ class ExportTfliteLibTest(tf.test.TestCase, parameterized.TestCase): ...@@ -78,26 +105,22 @@ class ExportTfliteLibTest(tf.test.TestCase, parameterized.TestCase):
@combinations.generate( @combinations.generate(
combinations.combine( combinations.combine(
experiment=['retinanet_mobile_coco'], experiment=['retinanet_mobile_coco'],
quant_type=[None, 'default', 'fp16'], quant_type=[None, 'default', 'fp16', 'int8', 'int8_full']))
input_image_size=[[384, 384]])) def test_export_tflite_detection(self, experiment, quant_type):
def test_export_tflite_detection(self, experiment, quant_type,
input_image_size):
test_tfrecord_file = os.path.join(self.get_temp_dir(), 'det_test.tfrecord')
example = tfexample_utils.create_detection_test_example(
image_height=input_image_size[0],
image_width=input_image_size[1],
image_channel=3,
num_instances=10)
self._create_test_tfrecord(
tfrecord_file=test_tfrecord_file, example=example, num_samples=10)
params = exp_factory.get_exp_config(experiment) params = exp_factory.get_exp_config(experiment)
params.task.validation_data.input_path = test_tfrecord_file params.task.validation_data.input_path = self.test_tfrecord_file_det
params.task.train_data.input_path = test_tfrecord_file params.task.train_data.input_path = self.test_tfrecord_file_det
params.task.model.num_classes = 2
params.task.model.backbone.spinenet_mobile.model_id = '49XS'
params.task.model.input_size = [128, 128, 3]
params.task.model.detection_generator.nms_version = 'v1'
params.task.train_data.shuffle_buffer_size = 5
temp_dir = self.get_temp_dir() temp_dir = self.get_temp_dir()
module = detection_serving.DetectionModule( module = detection_serving.DetectionModule(
params=params, params=params,
batch_size=1, batch_size=1,
input_image_size=input_image_size, input_image_size=[128, 128],
input_type='tflite') input_type='tflite')
self._export_from_module( self._export_from_module(
module=module, module=module,
...@@ -108,32 +131,25 @@ class ExportTfliteLibTest(tf.test.TestCase, parameterized.TestCase): ...@@ -108,32 +131,25 @@ class ExportTfliteLibTest(tf.test.TestCase, parameterized.TestCase):
saved_model_dir=os.path.join(temp_dir, 'saved_model'), saved_model_dir=os.path.join(temp_dir, 'saved_model'),
quant_type=quant_type, quant_type=quant_type,
params=params, params=params,
calibration_steps=5) calibration_steps=1)
self.assertIsInstance(tflite_model, bytes) self.assertIsInstance(tflite_model, bytes)
@combinations.generate( @combinations.generate(
combinations.combine( combinations.combine(
experiment=['mnv2_deeplabv3_pascal'], experiment=['mnv2_deeplabv3_pascal'],
quant_type=[None, 'default', 'fp16', 'int8', 'int8_full'], quant_type=[None, 'default', 'fp16', 'int8', 'int8_full']))
input_image_size=[[512, 512]])) def test_export_tflite_semantic_segmentation(self, experiment, quant_type):
def test_export_tflite_semantic_segmentation(self, experiment, quant_type,
input_image_size):
test_tfrecord_file = os.path.join(self.get_temp_dir(), 'seg_test.tfrecord')
example = tfexample_utils.create_segmentation_test_example(
image_height=input_image_size[0],
image_width=input_image_size[1],
image_channel=3)
self._create_test_tfrecord(
tfrecord_file=test_tfrecord_file, example=example, num_samples=10)
params = exp_factory.get_exp_config(experiment) params = exp_factory.get_exp_config(experiment)
params.task.validation_data.input_path = test_tfrecord_file params.task.validation_data.input_path = self.test_tfrecord_file_seg
params.task.train_data.input_path = test_tfrecord_file params.task.train_data.input_path = self.test_tfrecord_file_seg
params.task.train_data.shuffle_buffer_size = 10
temp_dir = self.get_temp_dir() temp_dir = self.get_temp_dir()
module = semantic_segmentation_serving.SegmentationModule( module = semantic_segmentation_serving.SegmentationModule(
params=params, params=params,
batch_size=1, batch_size=1,
input_image_size=input_image_size, input_image_size=[512, 512],
input_type='tflite') input_type='tflite')
self._export_from_module( self._export_from_module(
module=module, module=module,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment