Commit 3fb1e20f authored by Abdullah Rashwan, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 310487163
parent e9e6d17c
@@ -193,7 +193,7 @@ def _batch_examples(dataset, batch_size, max_length):
 def _read_and_batch_from_files(
-    file_pattern, batch_size, max_length, num_parallel_calls, shuffle, repeat,
+    file_pattern, batch_size, max_length, max_io_parallelism, shuffle, repeat,
     static_batch=False, num_replicas=1, ctx=None):
   """Create dataset where each item is a dict of "inputs" and "targets".
@@ -201,7 +201,7 @@ def _read_and_batch_from_files(
     file_pattern: String used to match the input TFRecord files.
     batch_size: Maximum number of tokens per global batch of examples.
     max_length: Maximum number of tokens per example
-    num_parallel_calls: Number of cpu cores for parallel input processing.
+    max_io_parallelism: Max number of cpu cores for parallel input processing.
     shuffle: If true, randomizes order of elements.
     repeat: Number of times to repeat the dataset. If None, the dataset is
       repeated forever.
@@ -237,13 +237,13 @@ def _read_and_batch_from_files(
   options.experimental_deterministic = False
   dataset = dataset.interleave(
       _load_records,
-      cycle_length=num_parallel_calls,
+      cycle_length=max_io_parallelism,
       num_parallel_calls=tf.data.experimental.AUTOTUNE).with_options(options)

   # Parse each tf.Example into a dictionary
   # TODO: Look into prefetch_input_elements for performance optimization.
   dataset = dataset.map(_parse_example,
-                        num_parallel_calls=num_parallel_calls)
+                        num_parallel_calls=tf.data.experimental.AUTOTUNE)

   # Remove examples where the input or target length exceeds the maximum length,
   dataset = dataset.filter(lambda x, y: _filter_max_length((x, y), max_length))
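For readers skimming the hunk above: the commit keeps max_io_parallelism as the cap on how many TFRecord files are read concurrently (the cycle_length of the interleave) while handing per-element parallelism to tf.data's autotuning. A minimal, self-contained sketch of that pattern follows; the record loader, the feature spec, and the function names here are assumptions for illustration, not the exact helpers from this file.

import tensorflow as tf

def _load_records(filename):
  # One TFRecordDataset per matched input file.
  return tf.data.TFRecordDataset(filename, buffer_size=8 * 1024 * 1024)

def _parse_example(serialized):
  # Assumed feature spec: variable-length int64 "inputs" and "targets".
  parsed = tf.io.parse_single_example(
      serialized,
      {"inputs": tf.io.VarLenFeature(tf.int64),
       "targets": tf.io.VarLenFeature(tf.int64)})
  return (tf.sparse.to_dense(parsed["inputs"]),
          tf.sparse.to_dense(parsed["targets"]))

def read_records(file_pattern, max_io_parallelism):
  dataset = tf.data.Dataset.list_files(file_pattern, shuffle=False)
  options = tf.data.Options()
  options.experimental_deterministic = False
  # File-level parallelism is capped by max_io_parallelism; the number of
  # concurrent interleave workers and the map parallelism are left to AUTOTUNE.
  dataset = dataset.interleave(
      _load_records,
      cycle_length=max_io_parallelism,
      num_parallel_calls=tf.data.experimental.AUTOTUNE).with_options(options)
  return dataset.map(_parse_example,
                     num_parallel_calls=tf.data.experimental.AUTOTUNE)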
@@ -289,7 +289,7 @@ def train_input_fn(params, ctx=None):
     return _generate_synthetic_data(params)
   return _read_and_batch_from_files(
       file_pattern, params["batch_size"], params["max_length"],
-      params["num_parallel_calls"], shuffle=True,
+      params["max_io_parallelism"], shuffle=True,
       repeat=params["repeat_dataset"], static_batch=params["static_batch"],
       num_replicas=params["num_gpus"], ctx=ctx)
@@ -301,7 +301,7 @@ def eval_input_fn(params, ctx=None):
     return _generate_synthetic_data(params)
   return _read_and_batch_from_files(
       file_pattern, params["batch_size"], params["max_length"],
-      params["num_parallel_calls"], shuffle=False, repeat=1,
+      params["max_io_parallelism"], shuffle=False, repeat=1,
       static_batch=params["static_batch"], num_replicas=params["num_gpus"],
       ctx=ctx)
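After this rename, both input functions expect the new key in the params dict. An illustrative (not exhaustive) set of the keys touched by this commit, with made-up values:

params = {
    "data_dir": "/tmp/translate_ende",  # assumed location of the TFRecords
    "batch_size": 4096,
    "max_length": 64,
    "max_io_parallelism": 16,  # replaces the old "num_parallel_calls" key
    "repeat_dataset": None,
    "static_batch": False,
    "num_gpus": 1,
    "use_synthetic_data": False,
}
train_ds = data_pipeline.train_input_fn(params)
eval_ds = data_pipeline.eval_input_fn(params)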
...
@@ -148,7 +148,7 @@ class TransformerTask(object):
     params["decode_batch_size"] = flags_obj.decode_batch_size
     params["decode_max_length"] = flags_obj.decode_max_length
     params["padded_decode"] = flags_obj.padded_decode
-    params["num_parallel_calls"] = (
+    params["max_io_parallelism"] = (
         flags_obj.num_parallel_calls or tf.data.experimental.AUTOTUNE)
     params["use_synthetic_data"] = flags_obj.use_synthetic_data
@@ -239,7 +239,7 @@ class TransformerTask(object):
     train_ds = data_pipeline.train_input_fn(params)
     map_data_fn = data_pipeline.map_data_for_transformer_fn
     train_ds = train_ds.map(
-        map_data_fn, num_parallel_calls=params["num_parallel_calls"])
+        map_data_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
     if params["use_ctl"]:
       train_ds_iterator = iter(train_ds)
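A small sketch of the flag-to-params plumbing in the hunk above, with the flags object stubbed out since the flag definitions are outside this diff; only the fallback-to-AUTOTUNE behaviour is the point.

import tensorflow as tf

class _StubFlags(object):
  # Stand-in for flags_obj; the real object comes from absl flags.
  num_parallel_calls = None  # flag left unset on the command line

flags_obj = _StubFlags()
params = {}
# The command-line flag keeps its old name but now populates the renamed key;
# when unset it falls back to tf.data autotuning.
params["max_io_parallelism"] = (
    flags_obj.num_parallel_calls or tf.data.experimental.AUTOTUNE)

# Per-element map transforms no longer consult the params key at all:
#   train_ds = train_ds.map(
#       map_data_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)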
...