".github/vscode:/vscode.git/clone" did not exist on "cbe2adc78ab7a34fafebbf4c83582d6c29a461ed"
Commit f46d7b9d authored by Fan Yang's avatar Fan Yang Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 468798864
parent ef4f89e3
......@@ -30,6 +30,39 @@ from official.vision.serving import semantic_segmentation as semantic_segmentati
class ExportTfliteLibTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
super().setUp()
# Create test data for image classification.
self.test_tfrecord_file_cls = os.path.join(self.get_temp_dir(),
'cls_test.tfrecord')
example = tf.train.Example.FromString(
tfexample_utils.create_classification_example(
image_height=224, image_width=224))
self._create_test_tfrecord(
tfrecord_file=self.test_tfrecord_file_cls,
example=example,
num_samples=10)
# Create test data for object detection.
self.test_tfrecord_file_det = os.path.join(self.get_temp_dir(),
'det_test.tfrecord')
example = tfexample_utils.create_detection_test_example(
image_height=128, image_width=128, image_channel=3, num_instances=10)
self._create_test_tfrecord(
tfrecord_file=self.test_tfrecord_file_det,
example=example,
num_samples=10)
# Create test data for semantic segmentation.
self.test_tfrecord_file_seg = os.path.join(self.get_temp_dir(),
'seg_test.tfrecord')
example = tfexample_utils.create_segmentation_test_example(
image_height=512, image_width=512, image_channel=3)
self._create_test_tfrecord(
tfrecord_file=self.test_tfrecord_file_seg,
example=example,
num_samples=10)
def _create_test_tfrecord(self, tfrecord_file, example, num_samples):
examples = [example] * num_samples
tfexample_utils.dump_to_tfrecord(
......@@ -43,24 +76,18 @@ class ExportTfliteLibTest(tf.test.TestCase, parameterized.TestCase):
@combinations.generate(
combinations.combine(
experiment=['mobilenet_imagenet'],
quant_type=[None, 'default', 'fp16', 'int8', 'int8_full'],
input_image_size=[[224, 224]]))
def test_export_tflite_image_classification(self, experiment, quant_type,
input_image_size):
test_tfrecord_file = os.path.join(self.get_temp_dir(), 'cls_test.tfrecord')
example = tf.train.Example.FromString(
tfexample_utils.create_classification_example(
image_height=input_image_size[0], image_width=input_image_size[1]))
self._create_test_tfrecord(
tfrecord_file=test_tfrecord_file, example=example, num_samples=10)
quant_type=[None, 'default', 'fp16', 'int8', 'int8_full']))
def test_export_tflite_image_classification(self, experiment, quant_type):
params = exp_factory.get_exp_config(experiment)
params.task.validation_data.input_path = test_tfrecord_file
params.task.train_data.input_path = test_tfrecord_file
params.task.validation_data.input_path = self.test_tfrecord_file_cls
params.task.train_data.input_path = self.test_tfrecord_file_cls
params.task.train_data.shuffle_buffer_size = 10
temp_dir = self.get_temp_dir()
module = image_classification_serving.ClassificationModule(
params=params,
batch_size=1,
input_image_size=input_image_size,
input_image_size=[224, 224],
input_type='tflite')
self._export_from_module(
module=module,
......@@ -78,26 +105,22 @@ class ExportTfliteLibTest(tf.test.TestCase, parameterized.TestCase):
@combinations.generate(
combinations.combine(
experiment=['retinanet_mobile_coco'],
quant_type=[None, 'default', 'fp16'],
input_image_size=[[384, 384]]))
def test_export_tflite_detection(self, experiment, quant_type,
input_image_size):
test_tfrecord_file = os.path.join(self.get_temp_dir(), 'det_test.tfrecord')
example = tfexample_utils.create_detection_test_example(
image_height=input_image_size[0],
image_width=input_image_size[1],
image_channel=3,
num_instances=10)
self._create_test_tfrecord(
tfrecord_file=test_tfrecord_file, example=example, num_samples=10)
quant_type=[None, 'default', 'fp16', 'int8', 'int8_full']))
def test_export_tflite_detection(self, experiment, quant_type):
params = exp_factory.get_exp_config(experiment)
params.task.validation_data.input_path = test_tfrecord_file
params.task.train_data.input_path = test_tfrecord_file
params.task.validation_data.input_path = self.test_tfrecord_file_det
params.task.train_data.input_path = self.test_tfrecord_file_det
params.task.model.num_classes = 2
params.task.model.backbone.spinenet_mobile.model_id = '49XS'
params.task.model.input_size = [128, 128, 3]
params.task.model.detection_generator.nms_version = 'v1'
params.task.train_data.shuffle_buffer_size = 5
temp_dir = self.get_temp_dir()
module = detection_serving.DetectionModule(
params=params,
batch_size=1,
input_image_size=input_image_size,
input_image_size=[128, 128],
input_type='tflite')
self._export_from_module(
module=module,
......@@ -108,32 +131,25 @@ class ExportTfliteLibTest(tf.test.TestCase, parameterized.TestCase):
saved_model_dir=os.path.join(temp_dir, 'saved_model'),
quant_type=quant_type,
params=params,
calibration_steps=5)
calibration_steps=1)
self.assertIsInstance(tflite_model, bytes)
@combinations.generate(
combinations.combine(
experiment=['mnv2_deeplabv3_pascal'],
quant_type=[None, 'default', 'fp16', 'int8', 'int8_full'],
input_image_size=[[512, 512]]))
def test_export_tflite_semantic_segmentation(self, experiment, quant_type,
input_image_size):
test_tfrecord_file = os.path.join(self.get_temp_dir(), 'seg_test.tfrecord')
example = tfexample_utils.create_segmentation_test_example(
image_height=input_image_size[0],
image_width=input_image_size[1],
image_channel=3)
self._create_test_tfrecord(
tfrecord_file=test_tfrecord_file, example=example, num_samples=10)
quant_type=[None, 'default', 'fp16', 'int8', 'int8_full']))
def test_export_tflite_semantic_segmentation(self, experiment, quant_type):
params = exp_factory.get_exp_config(experiment)
params.task.validation_data.input_path = test_tfrecord_file
params.task.train_data.input_path = test_tfrecord_file
params.task.validation_data.input_path = self.test_tfrecord_file_seg
params.task.train_data.input_path = self.test_tfrecord_file_seg
params.task.train_data.shuffle_buffer_size = 10
temp_dir = self.get_temp_dir()
module = semantic_segmentation_serving.SegmentationModule(
params=params,
batch_size=1,
input_image_size=input_image_size,
input_image_size=[512, 512],
input_type='tflite')
self._export_from_module(
module=module,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment