Unverified Commit 721cd512 authored by Ayush Dubey, committed by GitHub
Browse files

Add `input_context` to `input_fn` in cifar10_main. (#6414)

* Add `input_context` to `input_fn` in cifar10_main.

* Change sharding log message to be consistent with `dataset.shard` params.

* Lint
parent 7b5606a5
......@@ -107,9 +107,15 @@ def preprocess_image(image, is_training):
return image
def input_fn(is_training, data_dir, batch_size, num_epochs=1,
dtype=tf.float32, datasets_num_private_threads=None,
num_parallel_batches=1, parse_record_fn=parse_record):
def input_fn(is_training,
data_dir,
batch_size,
num_epochs=1,
dtype=tf.float32,
datasets_num_private_threads=None,
num_parallel_batches=1,
parse_record_fn=parse_record,
input_context=None):
"""Input function which provides batches for train or eval.
Args:
......@@ -121,6 +127,8 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1,
datasets_num_private_threads: Number of private threads for tf.data.
num_parallel_batches: Number of parallel batches for tf.data.
parse_record_fn: Function to use for parsing the records.
input_context: A `tf.distribute.InputContext` object passed in by
`tf.distribute.Strategy`.
Returns:
A dataset that can be used for iteration.
......@@ -128,6 +136,13 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1,
filenames = get_filenames(is_training, data_dir)
dataset = tf.data.FixedLengthRecordDataset(filenames, _RECORD_BYTES)
if input_context:
tf.compat.v1.logging.info(
'Sharding the dataset: input_pipeline_id=%d num_input_pipelines=%d' % (
input_context.input_pipeline_id, input_context.num_input_pipelines))
dataset = dataset.shard(input_context.num_input_pipelines,
input_context.input_pipeline_id)
return resnet_run_loop.process_record_dataset(
dataset=dataset,
is_training=is_training,
......
......@@ -189,9 +189,9 @@ def input_fn(is_training,
dataset = tf.data.Dataset.from_tensor_slices(filenames)
if input_context:
tf.compat.v1.logging.info('Sharding the dataset %d/%d' % (
(input_context.input_pipeline_id + 1),
input_context.num_input_pipelines))
tf.compat.v1.logging.info(
'Sharding the dataset: input_pipeline_id=%d num_input_pipelines=%d' % (
input_context.input_pipeline_id, input_context.num_input_pipelines))
dataset = dataset.shard(input_context.num_input_pipelines,
input_context.input_pipeline_id)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment