".github/vscode:/vscode.git/clone" did not exist on "9ab65d2c35db13fd579923f66db85e3ba738e215"
Commit 33f0aa0f authored by Fan Yang's avatar Fan Yang Committed by A. Unique TensorFlower
Browse files

Internal change.

PiperOrigin-RevId: 394116606
parent a9d416f1
...@@ -13,10 +13,10 @@ ...@@ -13,10 +13,10 @@
# limitations under the License. # limitations under the License.
"""Library to facilitate TFLite model conversion.""" """Library to facilitate TFLite model conversion."""
import functools import functools
from typing import Iterator, List, Optional from typing import Iterator, List, Optional
from absl import logging
import tensorflow as tf import tensorflow as tf
from official.core import config_definitions as cfg from official.core import config_definitions as cfg
...@@ -45,6 +45,7 @@ def create_representative_dataset( ...@@ -45,6 +45,7 @@ def create_representative_dataset(
# Ensure batch size is 1 for TFLite model. # Ensure batch size is 1 for TFLite model.
params.task.train_data.global_batch_size = 1 params.task.train_data.global_batch_size = 1
params.task.train_data.dtype = 'float32' params.task.train_data.dtype = 'float32'
logging.info('Task config: %s', params.task.as_dict())
return task.build_inputs(params=params.task.train_data) return task.build_inputs(params=params.task.train_data)
......
...@@ -54,6 +54,7 @@ class ExportTfliteLibTest(tf.test.TestCase, parameterized.TestCase): ...@@ -54,6 +54,7 @@ class ExportTfliteLibTest(tf.test.TestCase, parameterized.TestCase):
def test_export_tflite(self, experiment, quant_type, input_image_size): def test_export_tflite(self, experiment, quant_type, input_image_size):
params = exp_factory.get_exp_config(experiment) params = exp_factory.get_exp_config(experiment)
params.task.validation_data.input_path = self._test_tfrecord_file params.task.validation_data.input_path = self._test_tfrecord_file
params.task.train_data.input_path = self._test_tfrecord_file
temp_dir = self.get_temp_dir() temp_dir = self.get_temp_dir()
module = image_classification_serving.ClassificationModule( module = image_classification_serving.ClassificationModule(
params=params, batch_size=1, input_image_size=input_image_size) params=params, batch_size=1, input_image_size=input_image_size)
...@@ -66,7 +67,7 @@ class ExportTfliteLibTest(tf.test.TestCase, parameterized.TestCase): ...@@ -66,7 +67,7 @@ class ExportTfliteLibTest(tf.test.TestCase, parameterized.TestCase):
saved_model_dir=os.path.join(temp_dir, 'saved_model'), saved_model_dir=os.path.join(temp_dir, 'saved_model'),
quant_type=quant_type, quant_type=quant_type,
params=params, params=params,
calibration_steps=20) calibration_steps=5)
self.assertIsInstance(tflite_model, bytes) self.assertIsInstance(tflite_model, bytes)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment