Commit 1e48a60a authored by A. Unique TensorFlower

NCF input pipeline change to remove the unbatch/rebatch step when using offline-generated data

PiperOrigin-RevId: 265508969
parent 32235d83
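
For context: the diff below stops flattening offline pre-batched records into single examples and re-batching them, and instead requires the offline pre-batch size to equal the training batch size. What follows is a minimal sketch with toy data (not the NCF pipeline code itself) contrasting the old unbatch/rebatch pattern with the new direct use of pre-batched elements:

# Minimal sketch, assuming toy integer "records" already grouped offline into
# batches of pre_batch_size elements; not the NCF pipeline code itself.
import tensorflow.compat.v2 as tf

pre_batch_size = 4
batch_size = 4

# Stand-in for offline pre-batched data: 8 batches of 4 ids each.
pre_batched = tf.data.Dataset.from_tensor_slices(
    tf.reshape(tf.range(32), [8, pre_batch_size]))

# Old pattern: flatten every pre-built batch back into single elements, then
# batch again. Correct, but pays per-element overhead for no benefit when the
# sizes already match.
rebatched = pre_batched.unbatch().batch(batch_size, drop_remainder=True)

# New pattern: validate the sizes, then use the offline batches as-is.
if pre_batch_size != batch_size:
  raise ValueError("Pre-batch ({}) size is not equal to batch "
                   "size ({})".format(pre_batch_size, batch_size))
direct = pre_batched

# Both pipelines yield identical [4]-element batches.
for old_batch, new_batch in zip(rebatched, direct):
  tf.debugging.assert_equal(old_batch, new_batch)

When the two sizes already match, the unbatch/rebatch round trip only re-creates the batches it was given, adding per-element overhead; that is the step this commit removes for offline-generated data.
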
@@ -21,7 +21,6 @@ from __future__ import print_function
import functools
# pylint: disable=g-bad-import-order
import numpy as np
import tensorflow.compat.v2 as tf
# pylint: enable=g-bad-import-order
@@ -42,6 +41,9 @@ def create_dataset_from_tf_record_files(input_file_pattern,
  def make_dataset(files_dataset, shard_index):
    """Returns dataset for sharded tf record files."""
+   if pre_batch_size != batch_size:
+     raise ValueError("Pre-batch ({}) size is not equal to batch "
+                      "size ({})".format(pre_batch_size, batch_size))
    files_dataset = files_dataset.shard(NUM_SHARDS, shard_index)
    dataset = files_dataset.interleave(tf.data.TFRecordDataset)
    decode_fn = functools.partial(
@@ -50,8 +52,6 @@ def create_dataset_from_tf_record_files(input_file_pattern,
        is_training=is_training)
    dataset = dataset.map(
        decode_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
-   dataset = dataset.apply(tf.data.experimental.unbatch())
-   dataset = dataset.batch(batch_size, drop_remainder=True)
    return dataset

  dataset = tf.data.Dataset.range(NUM_SHARDS)
......
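
The rest of the file is truncated above. Purely to illustrate the shard-and-interleave structure visible in the hunks (tf.data.Dataset.range(NUM_SHARDS) feeding make_dataset, which shards a file dataset and interleaves per-shard readers), here is a standalone sketch with hypothetical file names; it is an assumed example, not the elided code:

# Standalone sketch of the shard-and-interleave pattern, with hypothetical
# file names; not the truncated remainder of ncf_input_pipeline.py.
import functools
import tensorflow.compat.v2 as tf

NUM_SHARDS = 4
# Hypothetical file list standing in for files matched by input_file_pattern.
files = tf.data.Dataset.from_tensor_slices(
    ["/tmp/ncf/train-{:05d}.tfrecord".format(i) for i in range(16)])

def make_dataset(files_dataset, shard_index):
  """Reads every NUM_SHARDS-th file, starting at shard_index."""
  shard_files = files_dataset.shard(NUM_SHARDS, shard_index)
  return shard_files.interleave(tf.data.TFRecordDataset)

# One shard index per element; each index expands into its shard's records.
dataset = tf.data.Dataset.range(NUM_SHARDS).interleave(
    functools.partial(make_dataset, files),
    cycle_length=NUM_SHARDS,
    num_parallel_calls=tf.data.experimental.AUTOTUNE)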