Commit 809eaba9 authored by Fan Yang's avatar Fan Yang Committed by A. Unique TensorFlower
Browse files

Support reading compressed tfrecord files.

PiperOrigin-RevId: 436868704
parent d479abde
...@@ -28,7 +28,8 @@ ...@@ -28,7 +28,8 @@
# ============================================================================== # ==============================================================================
"""Utility library for picking an appropriate dataset function.""" """Utility library for picking an appropriate dataset function."""
from typing import Any, Callable, Union, Type import functools
from typing import Any, Callable, Type, Union
import tensorflow as tf import tensorflow as tf
...@@ -38,5 +39,6 @@ PossibleDatasetType = Union[Type[tf.data.Dataset], Callable[[tf.Tensor], Any]] ...@@ -38,5 +39,6 @@ PossibleDatasetType = Union[Type[tf.data.Dataset], Callable[[tf.Tensor], Any]]
def pick_dataset_fn(file_type: str) -> PossibleDatasetType: def pick_dataset_fn(file_type: str) -> PossibleDatasetType:
if file_type == 'tfrecord': if file_type == 'tfrecord':
return tf.data.TFRecordDataset return tf.data.TFRecordDataset
if file_type == 'tfrecord_compressed':
return functools.partial(tf.data.TFRecordDataset, compression_type='GZIP')
raise ValueError('Unrecognized file_type: {}'.format(file_type)) raise ValueError('Unrecognized file_type: {}'.format(file_type))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment