Unverified Commit 04792078 authored by Ayush Dubey's avatar Ayush Dubey Committed by GitHub
Browse files

Shard input for distribution strategy. (#6349)

* Shard input for distribution strategy.

* Pass in input_context from real input_fn.

* Pass in input_context from real input_fn.

* Make pipeline id base 1 for better readability.
parent 3024bde6
......@@ -159,9 +159,15 @@ def parse_record(raw_record, is_training, dtype):
return image, label
def input_fn(is_training, data_dir, batch_size, num_epochs=1,
dtype=tf.float32, datasets_num_private_threads=None,
num_parallel_batches=1, parse_record_fn=parse_record):
def input_fn(is_training,
data_dir,
batch_size,
num_epochs=1,
dtype=tf.float32,
datasets_num_private_threads=None,
num_parallel_batches=1,
parse_record_fn=parse_record,
input_context=None):
"""Input function which provides batches for train or eval.
Args:
......@@ -173,6 +179,8 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1,
datasets_num_private_threads: Number of private threads for tf.data.
num_parallel_batches: Number of parallel batches for tf.data.
parse_record_fn: Function to use for parsing the records.
input_context: A `tf.distribute.InputContext` object passed in by
`tf.distribute.Strategy`.
Returns:
A dataset that can be used for iteration.
......@@ -180,6 +188,13 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1,
filenames = get_filenames(is_training, data_dir)
dataset = tf.data.Dataset.from_tensor_slices(filenames)
if input_context:
tf.compat.v1.logging.info('Sharding the dataset %d/%d' % (
(input_context.input_pipeline_id + 1),
input_context.num_input_pipelines))
dataset = dataset.shard(input_context.num_input_pipelines,
input_context.input_pipeline_id)
if is_training:
# Shuffle the input files
dataset = dataset.shuffle(buffer_size=_NUM_TRAIN_FILES)
......
......@@ -600,7 +600,7 @@ def resnet_main(
model_dir=flags_obj.model_dir,
batch_size=flags_obj.batch_size)
def input_fn_train(num_epochs):
def input_fn_train(num_epochs, input_context=None):
return input_function(
is_training=True,
data_dir=flags_obj.data_dir,
......@@ -609,7 +609,8 @@ def resnet_main(
num_epochs=num_epochs,
dtype=flags_core.get_tf_dtype(flags_obj),
datasets_num_private_threads=flags_obj.datasets_num_private_threads,
num_parallel_batches=flags_obj.datasets_num_parallel_batches)
num_parallel_batches=flags_obj.datasets_num_parallel_batches,
input_context=input_context)
def input_fn_eval():
return input_function(
......@@ -624,10 +625,13 @@ def resnet_main(
flags_obj.train_epochs)
use_train_and_evaluate = flags_obj.use_train_and_evaluate or (
distribution_strategy.__class__.__name__ == 'CollectiveAllReduceStrategy')
distribution_strategy.__class__.__name__ in [
'CollectiveAllReduceStrategy', 'MultiWorkerMirroredStrategy'])
if use_train_and_evaluate:
train_spec = tf.estimator.TrainSpec(
input_fn=lambda: input_fn_train(train_epochs), hooks=train_hooks,
input_fn=lambda input_context=None: input_fn_train(
train_epochs, input_context=input_context),
hooks=train_hooks,
max_steps=flags_obj.max_train_steps)
eval_spec = tf.estimator.EvalSpec(input_fn=input_fn_eval,
steps=flags_obj.max_train_steps)
......@@ -661,8 +665,11 @@ def resnet_main(
# value of num_train_epochs in the lambda function will not be changed
# before it is used. So it is safe to ignore the pylint error here
# pylint: disable=cell-var-from-loop
classifier.train(input_fn=lambda: input_fn_train(num_train_epochs),
hooks=train_hooks, max_steps=flags_obj.max_train_steps)
classifier.train(
input_fn=lambda input_context=None: input_fn_train(
num_train_epochs, input_context=input_context),
hooks=train_hooks,
max_steps=flags_obj.max_train_steps)
# flags_obj.max_train_steps is generally associated with testing and
# profiling. As a result it is frequently called with synthetic data,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment