"tests/vscode:/vscode.git/clone" did not exist on "00add9f2511dffb1c4bccc7a9eeff836c8143bbf"
Commit dc91f48b authored by Fan Yang's avatar Fan Yang Committed by A. Unique TensorFlower
Browse files

Support reading compressed tfrecord files.

PiperOrigin-RevId: 436868704
parent 204dd7ec
...@@ -28,7 +28,8 @@ ...@@ -28,7 +28,8 @@
# ============================================================================== # ==============================================================================
"""Utility library for picking an appropriate dataset function.""" """Utility library for picking an appropriate dataset function."""
from typing import Any, Callable, Union, Type import functools
from typing import Any, Callable, Type, Union
import tensorflow as tf import tensorflow as tf
...@@ -38,5 +39,6 @@ PossibleDatasetType = Union[Type[tf.data.Dataset], Callable[[tf.Tensor], Any]] ...@@ -38,5 +39,6 @@ PossibleDatasetType = Union[Type[tf.data.Dataset], Callable[[tf.Tensor], Any]]
def pick_dataset_fn(file_type: str) -> PossibleDatasetType: def pick_dataset_fn(file_type: str) -> PossibleDatasetType:
if file_type == 'tfrecord': if file_type == 'tfrecord':
return tf.data.TFRecordDataset return tf.data.TFRecordDataset
if file_type == 'tfrecord_compressed':
return functools.partial(tf.data.TFRecordDataset, compression_type='GZIP')
raise ValueError('Unrecognized file_type: {}'.format(file_type)) raise ValueError('Unrecognized file_type: {}'.format(file_type))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment