Commit 6a2de9bb authored by Bruce Fontaine, committed by A. Unique TensorFlower

Fix NCF input pipeline to avoid reading the same file multiple times in one epoch.

PiperOrigin-RevId: 322415899
parent f97e0231
@@ -25,10 +25,8 @@ import tensorflow.compat.v2 as tf
 # pylint: enable=g-bad-import-order
 
 from official.recommendation import constants as rconst
-from official.recommendation import movielens
 from official.recommendation import data_pipeline
-
-NUM_SHARDS = 16
+from official.recommendation import movielens
 
 
 def create_dataset_from_tf_record_files(input_file_pattern,
@@ -36,32 +34,23 @@ def create_dataset_from_tf_record_files(input_file_pattern,
                                         batch_size,
                                         is_training=True):
   """Creates dataset from (tf)records files for training/evaluation."""
+  if pre_batch_size != batch_size:
+    raise ValueError("Pre-batch ({}) size is not equal to batch "
+                     "size ({})".format(pre_batch_size, batch_size))
+
   files = tf.data.Dataset.list_files(input_file_pattern, shuffle=is_training)
 
-  def make_dataset(files_dataset, shard_index):
-    """Returns dataset for sharded tf record files."""
-    if pre_batch_size != batch_size:
-      raise ValueError("Pre-batch ({}) size is not equal to batch "
-                       "size ({})".format(pre_batch_size, batch_size))
-    files_dataset = files_dataset.shard(NUM_SHARDS, shard_index)
-    dataset = files_dataset.interleave(
-        tf.data.TFRecordDataset,
-        num_parallel_calls=tf.data.experimental.AUTOTUNE)
-    decode_fn = functools.partial(
-        data_pipeline.DatasetManager.deserialize,
-        batch_size=pre_batch_size,
-        is_training=is_training)
-    dataset = dataset.map(
-        decode_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
-    return dataset
-
-  dataset = tf.data.Dataset.range(NUM_SHARDS)
-  map_fn = functools.partial(make_dataset, files)
-  dataset = dataset.interleave(
-      map_fn,
-      cycle_length=NUM_SHARDS,
+  dataset = files.interleave(
+      tf.data.TFRecordDataset,
+      cycle_length=16,
       num_parallel_calls=tf.data.experimental.AUTOTUNE)
+  decode_fn = functools.partial(
+      data_pipeline.DatasetManager.deserialize,
+      batch_size=pre_batch_size,
+      is_training=is_training)
+  dataset = dataset.map(
+      decode_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
   dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
   return dataset
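The removed pipeline sharded the shuffled file list NUM_SHARDS ways and interleaved the shards. Since tf.data.Dataset.list_files(..., shuffle=True) reshuffles on every traversal by default, each shard took every 16th file of its own, independently shuffled permutation, so within a single epoch some files could be read by several shards while others were skipped, which matches the behavior the commit message describes. Below is a minimal sketch of that hazard; it is not part of the commit, and the glob path is hypothetical:

import collections

import tensorflow.compat.v2 as tf

NUM_SHARDS = 16  # the constant the commit deletes

# shuffle=True implies reshuffle_each_iteration=True, so every traversal
# of `files` sees a different order ("/tmp/ncf" is a placeholder path).
files = tf.data.Dataset.list_files("/tmp/ncf/*.tfrecord", shuffle=True)

seen = collections.Counter()
for index in range(NUM_SHARDS):
  # Each .shard() iterates `files` from scratch, triggering a fresh shuffle,
  # so shard `index` keeps every 16th file of a different permutation.
  for filename in files.shard(NUM_SHARDS, index):
    seen[filename.numpy()] += 1

print("files read more than once:", sum(1 for n in seen.values() if n > 1))

The patched code drops sharding entirely: files.interleave(tf.data.TFRecordDataset, cycle_length=16, ...) opens each listed file exactly once per epoch while still reading 16 files in parallel.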
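For reference, a hedged usage sketch of the patched function; the file pattern and batch sizes below are illustrative. Because the size check now sits at the top of the function rather than inside a nested dataset factory, a mismatch fails fast at call time:

dataset = create_dataset_from_tf_record_files(
    input_file_pattern="/tmp/ncf/train-*.tfrecord",  # hypothetical pattern
    pre_batch_size=1024,
    batch_size=1024,  # must equal pre_batch_size, else ValueError
    is_training=True)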