Commit be205db2 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 480706898
parent 3347c155
...@@ -189,8 +189,7 @@ def representative_dataset_gen(export_config): ...@@ -189,8 +189,7 @@ def representative_dataset_gen(export_config):
"""Gets a python generator of numpy arrays for the given dataset.""" """Gets a python generator of numpy arrays for the given dataset."""
quantization_config = export_config.quantization_config quantization_config = export_config.quantization_config
dataset = tfds.builder( dataset = tfds.builder(
quantization_config.dataset_name, quantization_config.dataset_name, try_gcs=True)
data_dir=quantization_config.dataset_dir)
dataset.download_and_prepare() dataset.download_and_prepare()
data = dataset.as_dataset()[quantization_config.dataset_split] data = dataset.as_dataset()[quantization_config.dataset_split]
iterator = data.as_numpy_iterator() iterator = data.as_numpy_iterator()
...@@ -207,7 +206,8 @@ def configure_tflite_converter(export_config, converter): ...@@ -207,7 +206,8 @@ def configure_tflite_converter(export_config, converter):
"""Common code for picking up quantization parameters.""" """Common code for picking up quantization parameters."""
quantization_config = export_config.quantization_config quantization_config = export_config.quantization_config
if quantization_config.quantize: if quantization_config.quantize:
if quantization_config.dataset_dir is None: if (quantization_config.dataset_dir is
None) and (quantization_config.dataset_name is None):
raise ValueError( raise ValueError(
'Must provide a representative dataset when quantizing the model.') 'Must provide a representative dataset when quantizing the model.')
converter.optimizations = [tf.lite.Optimize.DEFAULT] converter.optimizations = [tf.lite.Optimize.DEFAULT]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment