Commit ce445976 authored by Niranjan Hasabnis's avatar Niranjan Hasabnis Committed by Karmel Allison
Browse files

Support to set inter_op_parallelism_threads and intra_op_parallelism_threads for ResNet (#3502)

* Adding support to set inter_op_parallelism_threads and intra_op_parallelism_threads for official/resnet

* Addressing review comments
parent ea7481c8
......@@ -584,8 +584,15 @@ def resnet_main(flags, model_function, input_function):
model_function,
loss_reduction=tf.losses.Reduction.MEAN)
# Set up a RunConfig to only save checkpoints once per training cycle.
run_config = tf.estimator.RunConfig().replace(save_checkpoints_secs=1e9)
# Create session config based on values of inter_op_parallelism_threads and
# intra_op_parallelism_threads.
session_config = tf.ConfigProto(
inter_op_parallelism_threads=flags.inter_op_parallelism_threads,
intra_op_parallelism_threads=flags.intra_op_parallelism_threads)
# Set up a RunConfig to save checkpoint and set session config.
run_config = tf.estimator.RunConfig().replace(save_checkpoints_secs=1e9,
session_config=session_config)
classifier = tf.estimator.Estimator(
model_fn=model_function, model_dir=flags.model_dir, config=run_config,
params={
......@@ -675,3 +682,13 @@ class ResnetArgParser(argparse.ArgumentParser):
'--multi_gpu', action='store_true',
help='If set, run across all available GPUs. Note that this is '
'superseded by the --num_gpus flag.')
self.add_argument(
'--inter_op_parallelism_threads', type=int, default=0,
help='Number of inter_op_parallelism_threads to use for CPU. '
'See TensorFlow config.proto for details.')
self.add_argument(
'--intra_op_parallelism_threads', type=int, default=0,
help='Number of intra_op_parallelism_threads to use for CPU. '
'See TensorFlow config.proto for details.')
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment