Unverified Commit ad386df5 authored by Yuefeng Zhou's avatar Yuefeng Zhou Committed by GitHub
Browse files

Pass datasets_num_private_threads flag into Keras resnet model. (#6211)

parent 078575a1
...@@ -116,17 +116,20 @@ def run(flags_obj): ...@@ -116,17 +116,20 @@ def run(flags_obj):
distribution_utils.undo_set_up_synthetic_data() distribution_utils.undo_set_up_synthetic_data()
input_fn = imagenet_main.input_fn input_fn = imagenet_main.input_fn
train_input_dataset = input_fn(is_training=True, train_input_dataset = input_fn(
data_dir=flags_obj.data_dir, is_training=True,
batch_size=flags_obj.batch_size, data_dir=flags_obj.data_dir,
num_epochs=flags_obj.train_epochs, batch_size=flags_obj.batch_size,
parse_record_fn=parse_record_keras) num_epochs=flags_obj.train_epochs,
parse_record_fn=parse_record_keras,
eval_input_dataset = input_fn(is_training=False, datasets_num_private_threads=flags_obj.datasets_num_private_threads)
data_dir=flags_obj.data_dir,
batch_size=flags_obj.batch_size, eval_input_dataset = input_fn(
num_epochs=flags_obj.train_epochs, is_training=False,
parse_record_fn=parse_record_keras) data_dir=flags_obj.data_dir,
batch_size=flags_obj.batch_size,
num_epochs=flags_obj.train_epochs,
parse_record_fn=parse_record_keras)
strategy = distribution_utils.get_distribution_strategy( strategy = distribution_utils.get_distribution_strategy(
distribution_strategy=flags_obj.distribution_strategy, distribution_strategy=flags_obj.distribution_strategy,
......
...@@ -77,7 +77,6 @@ def process_record_dataset(dataset, ...@@ -77,7 +77,6 @@ def process_record_dataset(dataset,
# Defines a specific size thread pool for tf.data operations. # Defines a specific size thread pool for tf.data operations.
if datasets_num_private_threads: if datasets_num_private_threads:
options = tf.data.Options() options = tf.data.Options()
options.experimental_threading = tf.data.experimental.ThreadingOptions()
options.experimental_threading.private_threadpool_size = ( options.experimental_threading.private_threadpool_size = (
datasets_num_private_threads) datasets_num_private_threads)
dataset = dataset.with_options(options) dataset = dataset.with_options(options)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment