Unverified Commit ad386df5 authored by Yuefeng Zhou's avatar Yuefeng Zhou Committed by GitHub
Browse files

Pass datasets_num_private_threads flag into Keras resnet model. (#6211)

parent 078575a1
......@@ -116,17 +116,20 @@ def run(flags_obj):
distribution_utils.undo_set_up_synthetic_data()
input_fn = imagenet_main.input_fn
train_input_dataset = input_fn(is_training=True,
data_dir=flags_obj.data_dir,
batch_size=flags_obj.batch_size,
num_epochs=flags_obj.train_epochs,
parse_record_fn=parse_record_keras)
eval_input_dataset = input_fn(is_training=False,
data_dir=flags_obj.data_dir,
batch_size=flags_obj.batch_size,
num_epochs=flags_obj.train_epochs,
parse_record_fn=parse_record_keras)
train_input_dataset = input_fn(
is_training=True,
data_dir=flags_obj.data_dir,
batch_size=flags_obj.batch_size,
num_epochs=flags_obj.train_epochs,
parse_record_fn=parse_record_keras,
datasets_num_private_threads=flags_obj.datasets_num_private_threads)
eval_input_dataset = input_fn(
is_training=False,
data_dir=flags_obj.data_dir,
batch_size=flags_obj.batch_size,
num_epochs=flags_obj.train_epochs,
parse_record_fn=parse_record_keras)
strategy = distribution_utils.get_distribution_strategy(
distribution_strategy=flags_obj.distribution_strategy,
......
......@@ -77,7 +77,6 @@ def process_record_dataset(dataset,
# Defines a specific size thread pool for tf.data operations.
if datasets_num_private_threads:
options = tf.data.Options()
options.experimental_threading = tf.data.experimental.ThreadingOptions()
options.experimental_threading.private_threadpool_size = (
datasets_num_private_threads)
dataset = dataset.with_options(options)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment