"examples/git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "69467ea59003e950152e2abb9e447807c45cad79"
Commit ee832b66 authored by Marianne Linhares Monteiro's avatar Marianne Linhares Monteiro Committed by GitHub
Browse files

Adding adjust_learning_rate option

parent 2cc4b2e2
...@@ -73,23 +73,31 @@ tf.flags.DEFINE_float('momentum', 0.9, 'Momentum for MomentumOptimizer.') ...@@ -73,23 +73,31 @@ tf.flags.DEFINE_float('momentum', 0.9, 'Momentum for MomentumOptimizer.')
tf.flags.DEFINE_float('weight_decay', 2e-4, 'Weight decay for convolutions.') tf.flags.DEFINE_float('weight_decay', 2e-4, 'Weight decay for convolutions.')
tf.flags.DEFINE_float('adjust_learning_rate', 1,
"""This value will be multiplied by the learning rate.
By default the learning rate is
[0.1, 0.01, 0.001, 0.0002]
""")
tf.flags.DEFINE_boolean('use_distortion_for_training', True, tf.flags.DEFINE_boolean('use_distortion_for_training', True,
'If doing image distortion for training.') 'If doing image distortion for training.')
tf.flags.DEFINE_boolean('run_experiment', False, tf.flags.DEFINE_boolean('run_experiment', False,
'If True will run an experiment,' """If True will run an experiment,
'otherwise will run training and evaluation' otherwise will run training and evaluation
'using the estimator interface.' using the estimator interface.
'Experiments perform training on several workers in' Experiments perform training on several workers in
'parallel, in other words experiments know how to' parallel, in other words experiments know how to
' invoke train and eval in a sensible fashion for' invoke train and eval in a sensible fashion for
' distributed training.') distributed training.
""")
tf.flags.DEFINE_boolean('sync', False, tf.flags.DEFINE_boolean('sync', False,
'If true when running in a distributed environment' """If true when running in a distributed environment
'will run on sync mode') will run on sync mode.
""")
tf.flags.DEFINE_integer('num_workers', 1, 'Number of workers') tf.flags.DEFINE_integer('num_workers', 1, 'Number of workers.')
# Perf flags # Perf flags
tf.flags.DEFINE_integer('num_intra_threads', 1, tf.flags.DEFINE_integer('num_intra_threads', 1,
...@@ -308,7 +316,10 @@ def _resnet_model_fn(features, labels, mode): ...@@ -308,7 +316,10 @@ def _resnet_model_fn(features, labels, mode):
num_batches_per_epoch * x num_batches_per_epoch * x
for x in np.array([82, 123, 300], dtype=np.int64) for x in np.array([82, 123, 300], dtype=np.int64)
] ]
staged_lr = [0.1, 0.01, 0.001, 0.0002] staged_lr = [
FLAGS.adjust_learning_rate * x
for x in [0.1, 0.01, 0.001, 0.0002]]
learning_rate = tf.train.piecewise_constant(tf.train.get_global_step(), learning_rate = tf.train.piecewise_constant(tf.train.get_global_step(),
boundaries, staged_lr) boundaries, staged_lr)
# Create a nicely-named tensor for logging # Create a nicely-named tensor for logging
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment