Commit 534e86b3 authored by Eli Bixby's avatar Eli Bixby
Browse files

Fix incorrect SyncReplicasOptimizer usage

parent 6024579b
...@@ -153,18 +153,26 @@ def get_model_fn(num_gpus, variable_strategy, num_workers): ...@@ -153,18 +153,26 @@ def get_model_fn(num_gpus, variable_strategy, num_workers):
learning_rate = tf.train.piecewise_constant(tf.train.get_global_step(), learning_rate = tf.train.piecewise_constant(tf.train.get_global_step(),
boundaries, staged_lr) boundaries, staged_lr)
# Create a nicely-named tensor for logging
learning_rate = tf.identity(learning_rate, name='learning_rate')
loss = tf.reduce_mean(tower_losses, name='loss')
optimizer = tf.train.MomentumOptimizer( optimizer = tf.train.MomentumOptimizer(
learning_rate=learning_rate, momentum=momentum) learning_rate=learning_rate, momentum=momentum)
chief_hooks = []
examples_sec_hook = cifar10_utils.ExamplesPerSecondHook(
params.train_batch_size, every_n_steps=10)
tensors_to_log = {'learning_rate': learning_rate, 'loss': loss}
logging_hook = tf.train.LoggingTensorHook(
tensors=tensors_to_log, every_n_iter=100)
train_hooks = [logging_hook, examples_sec_hook]
if params.sync: if params.sync:
optimizer = tf.train.SyncReplicasOptimizer( optimizer = tf.train.SyncReplicasOptimizer(
optimizer, replicas_to_aggregate=num_workers) optimizer, replicas_to_aggregate=num_workers)
sync_replicas_hook = optimizer.make_session_run_hook(True) sync_replicas_hook = optimizer.make_session_run_hook(params.is_chief)
chief_hooks.append(sync_replicas_hook) train_hooks.append(sync_replicas_hook)
# Create single grouped train op # Create single grouped train op
train_op = [ train_op = [
...@@ -185,14 +193,13 @@ def get_model_fn(num_gpus, variable_strategy, num_workers): ...@@ -185,14 +193,13 @@ def get_model_fn(num_gpus, variable_strategy, num_workers):
'accuracy': 'accuracy':
tf.metrics.accuracy(stacked_labels, predictions['classes']) tf.metrics.accuracy(stacked_labels, predictions['classes'])
} }
loss = tf.reduce_mean(tower_losses, name='loss')
return tf.estimator.EstimatorSpec( return tf.estimator.EstimatorSpec(
mode=mode, mode=mode,
predictions=predictions, predictions=predictions,
loss=loss, loss=loss,
train_op=train_op, train_op=train_op,
training_chief_hooks=chief_hooks, training_hooks=train_hooks,
eval_metric_ops=metrics) eval_metric_ops=metrics)
return _resnet_model_fn return _resnet_model_fn
...@@ -336,32 +343,24 @@ def get_experiment_fn(data_dir, ...@@ -336,32 +343,24 @@ def get_experiment_fn(data_dir,
train_steps = hparams.train_steps train_steps = hparams.train_steps
eval_steps = num_eval_examples // hparams.eval_batch_size eval_steps = num_eval_examples // hparams.eval_batch_size
examples_sec_hook = cifar10_utils.ExamplesPerSecondHook(
hparams.train_batch_size, every_n_steps=10)
tensors_to_log = {'learning_rate': 'learning_rate', 'loss': 'loss'} if run_config.num_worker_replicas:
num_workers = run_config.num_worker_replicas + 1
logging_hook = tf.train.LoggingTensorHook( else:
tensors=tensors_to_log, every_n_iter=100) num_workers = 1
hooks = [logging_hook, examples_sec_hook]
classifier = tf.estimator.Estimator( classifier = tf.estimator.Estimator(
model_fn=get_model_fn(num_gpus, variable_strategy, model_fn=get_model_fn(num_gpus, variable_strategy, num_workers),
run_config.num_worker_replicas or 1),
config=run_config, config=run_config,
params=hparams) params=hparams)
# Create experiment. # Create experiment.
experiment = tf.contrib.learn.Experiment( return tf.contrib.learn.Experiment(
classifier, classifier,
train_input_fn=train_input_fn, train_input_fn=train_input_fn,
eval_input_fn=eval_input_fn, eval_input_fn=eval_input_fn,
train_steps=train_steps, train_steps=train_steps,
eval_steps=eval_steps) eval_steps=eval_steps)
# Adding hooks to be used by the estimator on training modes
experiment.extend_train_hooks(hooks)
return experiment
return _experiment_fn return _experiment_fn
...@@ -386,7 +385,9 @@ def main(job_dir, data_dir, num_gpus, variable_strategy, ...@@ -386,7 +385,9 @@ def main(job_dir, data_dir, num_gpus, variable_strategy,
get_experiment_fn(data_dir, num_gpus, variable_strategy, get_experiment_fn(data_dir, num_gpus, variable_strategy,
use_distortion_for_training), use_distortion_for_training),
run_config=config, run_config=config,
hparams=tf.contrib.training.HParams(**hparams)) hparams=tf.contrib.training.HParams(
is_chief=config.is_chief,
**hparams))
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment