"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "ad8068e4143c9984181e3dae877505e0fe4d147d"
Commit 866176ea authored by André Susano Pinto, committed by A. Unique TensorFlower

When not specified by the user, make steps_per_loop dependent on the devices available.

A default of 1 in all cases is bad for TPU users, who end up not using the device effectively.
A larger default in all cases is bad for GPU users.

So compromise and make the default dependent on the devices available.

PiperOrigin-RevId: 312230371
parent c4bf6d3e
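
The device check the new default keys off is `tf.config.list_logical_devices('TPU')`. A minimal standalone probe of that condition (assumes TensorFlow 2.x; not part of this commit):

```python
# Probe of the condition used in the diff below (assumes TensorFlow 2.x).
import tensorflow as tf

# list_logical_devices('TPU') returns a non-empty list only when TPU
# logical devices are visible to the runtime.
if tf.config.list_logical_devices('TPU'):
    print('TPU visible: a large steps_per_loop keeps the device busy.')
else:
    print('No TPU visible: steps_per_loop=1 keeps callbacks responsive.')
```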
```diff
...
@@ -49,10 +49,11 @@ def define_common_bert_flags():
   flags.DEFINE_integer('num_train_epochs', 3,
                        'Total number of training epochs to perform.')
   flags.DEFINE_integer(
-      'steps_per_loop', 1,
+      'steps_per_loop', None,
       'Number of steps per graph-mode loop. Only training step '
       'happens inside the loop. Callbacks will not be called '
-      'inside.')
+      'inside. If not set the value will be configured depending on the '
+      'devices available.')
   flags.DEFINE_float('learning_rate', 5e-5,
                      'The initial learning rate for Adam.')
   flags.DEFINE_float('end_lr', 0.0,
...
```
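
For context, an absl flag declared with a `None` default stays `None` until code downstream resolves it. A hypothetical minimal script (not from this commit) mirroring the new declaration:

```python
# Hypothetical demo of the None-default flag pattern used above.
from absl import app, flags

flags.DEFINE_integer(
    'steps_per_loop', None,
    'Number of steps per graph-mode loop. If not set the value will be '
    'configured depending on the devices available.')

FLAGS = flags.FLAGS


def main(_):
  if FLAGS.steps_per_loop is None:
    print('steps_per_loop unset; deferring to device-dependent default.')
  else:
    print('steps_per_loop explicitly set to', FLAGS.steps_per_loop)


if __name__ == '__main__':
  app.run(main)
```

Run with no arguments it takes the deferred branch; `--steps_per_loop=200` takes the explicit one.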
```diff
...
@@ -111,7 +111,7 @@ def run_customized_training_loop(
     model_dir=None,
     train_input_fn=None,
     steps_per_epoch=None,
-    steps_per_loop=1,
+    steps_per_loop=None,
     epochs=1,
     eval_input_fn=None,
     eval_steps=None,
@@ -210,10 +210,19 @@ def run_customized_training_loop(
   ]
   if [arg for arg in required_arguments if arg is None]:
     raise ValueError('`strategy`, `model_fn`, `loss_fn`, `model_dir`, '
-                     '`steps_per_loop` and `steps_per_epoch` are required '
+                     '`steps_per_epoch` and `train_input_fn` are required '
                      'parameters.')
+  if not steps_per_loop:
+    if tf.config.list_logical_devices('TPU'):
+      # One can't fully utilize a TPU with steps_per_loop=1, so in this case
+      # default users to a more useful value.
+      steps_per_loop = min(1000, steps_per_epoch)
+    else:
+      steps_per_loop = 1
+    logging.info('steps_per_loop not specified. Using steps_per_loop=%d',
+                 steps_per_loop)
   if steps_per_loop > steps_per_epoch:
-    logging.error(
+    logging.warning(
         'steps_per_loop: %d is specified to be greater than '
         ' steps_per_epoch: %d, we will use steps_per_epoch as'
         ' steps_per_loop.', steps_per_loop, steps_per_epoch)
...
```
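
Pulled out of the training loop for clarity, the defaulting logic above amounts to the following sketch. The helper name `resolve_steps_per_loop` is mine, not the repo's, and the final clamp is implied by the warning message rather than shown in the hunk (assumes TensorFlow 2.x and `absl.logging`):

```python
# Sketch of the device-dependent default added above; the helper name is
# hypothetical, the behavior mirrors the diff.
import tensorflow as tf
from absl import logging


def resolve_steps_per_loop(steps_per_loop, steps_per_epoch):
  if not steps_per_loop:
    if tf.config.list_logical_devices('TPU'):
      # One can't fully utilize a TPU with steps_per_loop=1, so default
      # TPU users to a more useful value, capped at one epoch.
      steps_per_loop = min(1000, steps_per_epoch)
    else:
      steps_per_loop = 1
    logging.info('steps_per_loop not specified. Using steps_per_loop=%d',
                 steps_per_loop)
  if steps_per_loop > steps_per_epoch:
    logging.warning(
        'steps_per_loop: %d is specified to be greater than '
        'steps_per_epoch: %d, we will use steps_per_epoch as '
        'steps_per_loop.', steps_per_loop, steps_per_epoch)
    # Clamp as the message promises (the clamp itself sits outside the
    # lines shown in the hunk above).
    steps_per_loop = steps_per_epoch
  return steps_per_loop
```

On a TPU host with `steps_per_epoch=5000` this resolves an unset value to 1000; on a GPU or CPU host it resolves to 1, preserving the old behavior.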