Commit 33f0aa0f authored by Fan Yang's avatar Fan Yang Committed by A. Unique TensorFlower
Browse files

Internal change.

PiperOrigin-RevId: 394116606
parent a9d416f1
......@@ -13,10 +13,10 @@
# limitations under the License.
"""Library to facilitate TFLite model conversion."""
import functools
from typing import Iterator, List, Optional
from absl import logging
import tensorflow as tf
from official.core import config_definitions as cfg
......@@ -45,6 +45,7 @@ def create_representative_dataset(
# Ensure batch size is 1 for TFLite model.
params.task.train_data.global_batch_size = 1
params.task.train_data.dtype = 'float32'
logging.info('Task config: %s', params.task.as_dict())
return task.build_inputs(params=params.task.train_data)
......
......@@ -54,6 +54,7 @@ class ExportTfliteLibTest(tf.test.TestCase, parameterized.TestCase):
def test_export_tflite(self, experiment, quant_type, input_image_size):
params = exp_factory.get_exp_config(experiment)
params.task.validation_data.input_path = self._test_tfrecord_file
params.task.train_data.input_path = self._test_tfrecord_file
temp_dir = self.get_temp_dir()
module = image_classification_serving.ClassificationModule(
params=params, batch_size=1, input_image_size=input_image_size)
......@@ -66,7 +67,7 @@ class ExportTfliteLibTest(tf.test.TestCase, parameterized.TestCase):
saved_model_dir=os.path.join(temp_dir, 'saved_model'),
quant_type=quant_type,
params=params,
calibration_steps=20)
calibration_steps=5)
self.assertIsInstance(tflite_model, bytes)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment