Commit d0269919 authored by Eli Bixby's avatar Eli Bixby
Browse files

Move optimizer closer to usage

parent 534e86b3
......@@ -155,9 +155,6 @@ def get_model_fn(num_gpus, variable_strategy, num_workers):
boundaries, staged_lr)
loss = tf.reduce_mean(tower_losses, name='loss')
optimizer = tf.train.MomentumOptimizer(
learning_rate=learning_rate, momentum=momentum)
examples_sec_hook = cifar10_utils.ExamplesPerSecondHook(
params.train_batch_size, every_n_steps=10)
......@@ -168,6 +165,10 @@ def get_model_fn(num_gpus, variable_strategy, num_workers):
tensors=tensors_to_log, every_n_iter=100)
train_hooks = [logging_hook, examples_sec_hook]
optimizer = tf.train.MomentumOptimizer(
learning_rate=learning_rate, momentum=momentum)
if params.sync:
optimizer = tf.train.SyncReplicasOptimizer(
optimizer, replicas_to_aggregate=num_workers)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment