Commit 3fb1e20f authored by Abdullah Rashwan, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 310487163
parent e9e6d17c
@@ -193,7 +193,7 @@ def _batch_examples(dataset, batch_size, max_length):
 def _read_and_batch_from_files(
-    file_pattern, batch_size, max_length, num_parallel_calls, shuffle, repeat,
+    file_pattern, batch_size, max_length, max_io_parallelism, shuffle, repeat,
     static_batch=False, num_replicas=1, ctx=None):
   """Create dataset where each item is a dict of "inputs" and "targets".
@@ -201,7 +201,7 @@ def _read_and_batch_from_files(
     file_pattern: String used to match the input TFRecord files.
     batch_size: Maximum number of tokens per global batch of examples.
     max_length: Maximum number of tokens per example.
-    num_parallel_calls: Number of cpu cores for parallel input processing.
+    max_io_parallelism: Max number of cpu cores for parallel input processing.
     shuffle: If true, randomizes order of elements.
     repeat: Number of times to repeat the dataset. If None, the dataset is
       repeated forever.
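Review note: the records matched by file_pattern hold serialized tf.Example protos with "inputs" and "targets" features, per the docstring above. A hedged sketch of a parser along the lines of the _parse_example used further down; the exact feature spec here is an assumption, not part of this commit:

import tensorflow as tf

def parse_example_sketch(serialized_example):
  # Variable-length sequences of token ids for source and target
  # (assumed feature spec, for illustration only).
  data_fields = {
      "inputs": tf.io.VarLenFeature(tf.int64),
      "targets": tf.io.VarLenFeature(tf.int64),
  }
  parsed = tf.io.parse_single_example(serialized_example, data_fields)
  return (tf.sparse.to_dense(parsed["inputs"]),
          tf.sparse.to_dense(parsed["targets"]))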
@@ -237,13 +237,13 @@ def _read_and_batch_from_files(
   options.experimental_deterministic = False
   dataset = dataset.interleave(
       _load_records,
-      cycle_length=num_parallel_calls,
+      cycle_length=max_io_parallelism,
       num_parallel_calls=tf.data.experimental.AUTOTUNE).with_options(options)

   # Parse each tf.Example into a dictionary
   # TODO: Look into prefetch_input_elements for performance optimization.
   dataset = dataset.map(_parse_example,
-                        num_parallel_calls=num_parallel_calls)
+                        num_parallel_calls=tf.data.experimental.AUTOTUNE)

   # Remove examples where the input or target length exceeds the maximum length,
   dataset = dataset.filter(lambda x, y: _filter_max_length((x, y), max_length))
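Review note: the rename makes the division of labor explicit. In tf.data, cycle_length bounds how many input files are consumed concurrently, and with num_parallel_calls set to tf.data.experimental.AUTOTUNE the runtime tunes parallelism up to that bound. A minimal self-contained sketch of the pattern above, with a hypothetical file pattern and an assumed cap:

import tensorflow as tf

max_io_parallelism = 16  # assumed cap; the real value comes from flags

files = tf.data.Dataset.list_files("/tmp/shards-*.tfrecord")  # hypothetical
options = tf.data.Options()
options.experimental_deterministic = False  # allow out-of-order reads
dataset = files.interleave(
    tf.data.TFRecordDataset,
    cycle_length=max_io_parallelism,                  # bounded I/O fan-out
    num_parallel_calls=tf.data.experimental.AUTOTUNE  # tuned at runtime
).with_options(options)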
@@ -289,7 +289,7 @@ def train_input_fn(params, ctx=None):
     return _generate_synthetic_data(params)
   return _read_and_batch_from_files(
       file_pattern, params["batch_size"], params["max_length"],
-      params["num_parallel_calls"], shuffle=True,
+      params["max_io_parallelism"], shuffle=True,
       repeat=params["repeat_dataset"], static_batch=params["static_batch"],
       num_replicas=params["num_gpus"], ctx=ctx)
@@ -301,7 +301,7 @@ def eval_input_fn(params, ctx=None):
     return _generate_synthetic_data(params)
   return _read_and_batch_from_files(
       file_pattern, params["batch_size"], params["max_length"],
-      params["num_parallel_calls"], shuffle=False, repeat=1,
+      params["max_io_parallelism"], shuffle=False, repeat=1,
       static_batch=params["static_batch"], num_replicas=params["num_gpus"],
       ctx=ctx)
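Review note: both input functions now read the renamed key from params. An illustrative slice of that dict; the values are placeholders, not taken from this commit:

params = {
    "batch_size": 4096,        # max tokens per global batch (placeholder)
    "max_length": 256,
    "max_io_parallelism": 16,  # was "num_parallel_calls" before this change
    "repeat_dataset": None,    # None means repeat forever
    "static_batch": False,
    "num_gpus": 1,
}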
@@ -148,7 +148,7 @@ class TransformerTask(object):
     params["decode_batch_size"] = flags_obj.decode_batch_size
     params["decode_max_length"] = flags_obj.decode_max_length
     params["padded_decode"] = flags_obj.padded_decode
-    params["num_parallel_calls"] = (
+    params["max_io_parallelism"] = (
         flags_obj.num_parallel_calls or tf.data.experimental.AUTOTUNE)
     params["use_synthetic_data"] = flags_obj.use_synthetic_data
@@ -239,7 +239,7 @@ class TransformerTask(object):
     train_ds = data_pipeline.train_input_fn(params)
     map_data_fn = data_pipeline.map_data_for_transformer_fn
     train_ds = train_ds.map(
-        map_data_fn, num_parallel_calls=params["num_parallel_calls"])
+        map_data_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
     if params["use_ctl"]:
       train_ds_iterator = iter(train_ds)
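Review note: here the per-batch map no longer borrows the I/O parallelism setting and instead lets tf.data autotune it. A hedged sketch with a stub in place of map_data_for_transformer_fn (the real mapping logic lives in data_pipeline and is not reproduced here):

import tensorflow as tf

def map_data_stub(inputs, targets):
  # Stand-in for data_pipeline.map_data_for_transformer_fn.
  return (inputs, targets), targets

train_ds = tf.data.Dataset.from_tensor_slices(
    (tf.zeros([8, 4], tf.int64), tf.zeros([8, 4], tf.int64)))
train_ds = train_ds.map(
    map_data_stub, num_parallel_calls=tf.data.experimental.AUTOTUNE)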