Commit e7027fec authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 481251614
parent 6f521d91
......@@ -33,13 +33,14 @@ def write_small_dataset(examples: Sequence[Union[tf.train.Example,
examples: List of tf.train.Example or tf.train.SequenceExample.
output_path: Output path for the dataset.
file_type: A string indicating the file format, could be: 'tfrecord',
'tfrecord_compressed', 'riegeli'.
'tfrecords', 'tfrecord_compressed', 'tfrecords_gzip', 'riegeli'. The
string is case insensitive.
"""
file_type = file_type.lower()
if file_type == 'tfrecord':
if file_type == 'tfrecord' or file_type == 'tfrecords':
_write_tfrecord(examples, output_path)
elif file_type == 'tfrecord_compressed':
elif file_type == 'tfrecord_compressed' or file_type == 'tfrecords_gzip':
_write_tfrecord(examples, output_path,
tf.io.TFRecordOptions(compression_type='GZIP'))
elif file_type == 'riegeli':
......
......@@ -30,8 +30,9 @@ class FileWritersTest(tf.test.TestCase, parameterized.TestCase):
example_builder.add_bytes_feature('foo', 'Hello World!')
self._example = example_builder.example
@parameterized.parameters('tfrecord', 'TFRecord', 'tfrecord_compressed',
'TFRecord_Compressed')
@parameterized.parameters('tfrecord', 'TFRecord', 'tfrecords',
'tfrecord_compressed', 'TFRecord_Compressed',
'tfrecords_gzip')
def test_write_small_dataset_success(self, file_type):
temp_dir = self.create_tempdir()
temp_dataset_file = os.path.join(temp_dir.full_path, 'train')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment