Commit dd9a81c0, authored by Igor and committed by GitHub
Browse files

Merge pull request #2343 from elibixby/fixsyncreplicausage

Fix incorrect SyncReplicasOptimizer usage
parents 78bb025c 28d37e7a
...@@ -153,18 +153,27 @@ def get_model_fn(num_gpus, variable_strategy, num_workers): ...@@ -153,18 +153,27 @@ def get_model_fn(num_gpus, variable_strategy, num_workers):
learning_rate = tf.train.piecewise_constant(tf.train.get_global_step(), learning_rate = tf.train.piecewise_constant(tf.train.get_global_step(),
boundaries, staged_lr) boundaries, staged_lr)
# Create a nicely-named tensor for logging
learning_rate = tf.identity(learning_rate, name='learning_rate') loss = tf.reduce_mean(tower_losses, name='loss')
examples_sec_hook = cifar10_utils.ExamplesPerSecondHook(
params.train_batch_size, every_n_steps=10)
tensors_to_log = {'learning_rate': learning_rate, 'loss': loss}
logging_hook = tf.train.LoggingTensorHook(
tensors=tensors_to_log, every_n_iter=100)
train_hooks = [logging_hook, examples_sec_hook]
optimizer = tf.train.MomentumOptimizer( optimizer = tf.train.MomentumOptimizer(
learning_rate=learning_rate, momentum=momentum) learning_rate=learning_rate, momentum=momentum)
chief_hooks = []
if params.sync: if params.sync:
optimizer = tf.train.SyncReplicasOptimizer( optimizer = tf.train.SyncReplicasOptimizer(
optimizer, replicas_to_aggregate=num_workers) optimizer, replicas_to_aggregate=num_workers)
sync_replicas_hook = optimizer.make_session_run_hook(True) sync_replicas_hook = optimizer.make_session_run_hook(params.is_chief)
chief_hooks.append(sync_replicas_hook) train_hooks.append(sync_replicas_hook)
# Create single grouped train op # Create single grouped train op
train_op = [ train_op = [
...@@ -185,14 +194,13 @@ def get_model_fn(num_gpus, variable_strategy, num_workers): ...@@ -185,14 +194,13 @@ def get_model_fn(num_gpus, variable_strategy, num_workers):
'accuracy': 'accuracy':
tf.metrics.accuracy(stacked_labels, predictions['classes']) tf.metrics.accuracy(stacked_labels, predictions['classes'])
} }
loss = tf.reduce_mean(tower_losses, name='loss')
return tf.estimator.EstimatorSpec( return tf.estimator.EstimatorSpec(
mode=mode, mode=mode,
predictions=predictions, predictions=predictions,
loss=loss, loss=loss,
train_op=train_op, train_op=train_op,
training_chief_hooks=chief_hooks, training_hooks=train_hooks,
eval_metric_ops=metrics) eval_metric_ops=metrics)
return _resnet_model_fn return _resnet_model_fn
...@@ -336,16 +344,7 @@ def get_experiment_fn(data_dir, ...@@ -336,16 +344,7 @@ def get_experiment_fn(data_dir,
train_steps = hparams.train_steps train_steps = hparams.train_steps
eval_steps = num_eval_examples // hparams.eval_batch_size eval_steps = num_eval_examples // hparams.eval_batch_size
examples_sec_hook = cifar10_utils.ExamplesPerSecondHook(
hparams.train_batch_size, every_n_steps=10)
tensors_to_log = {'learning_rate': 'learning_rate', 'loss': 'loss'}
logging_hook = tf.train.LoggingTensorHook(
tensors=tensors_to_log, every_n_iter=100)
hooks = [logging_hook, examples_sec_hook]
classifier = tf.estimator.Estimator( classifier = tf.estimator.Estimator(
model_fn=get_model_fn(num_gpus, variable_strategy, model_fn=get_model_fn(num_gpus, variable_strategy,
run_config.num_worker_replicas or 1), run_config.num_worker_replicas or 1),
...@@ -353,15 +352,12 @@ def get_experiment_fn(data_dir, ...@@ -353,15 +352,12 @@ def get_experiment_fn(data_dir,
params=hparams) params=hparams)
# Create experiment. # Create experiment.
experiment = tf.contrib.learn.Experiment( return tf.contrib.learn.Experiment(
classifier, classifier,
train_input_fn=train_input_fn, train_input_fn=train_input_fn,
eval_input_fn=eval_input_fn, eval_input_fn=eval_input_fn,
train_steps=train_steps, train_steps=train_steps,
eval_steps=eval_steps) eval_steps=eval_steps)
# Adding hooks to be used by the estimator on training modes
experiment.extend_train_hooks(hooks)
return experiment
return _experiment_fn return _experiment_fn
...@@ -386,7 +382,9 @@ def main(job_dir, data_dir, num_gpus, variable_strategy, ...@@ -386,7 +382,9 @@ def main(job_dir, data_dir, num_gpus, variable_strategy,
get_experiment_fn(data_dir, num_gpus, variable_strategy, get_experiment_fn(data_dir, num_gpus, variable_strategy,
use_distortion_for_training), use_distortion_for_training),
run_config=config, run_config=config,
hparams=tf.contrib.training.HParams(**hparams)) hparams=tf.contrib.training.HParams(
is_chief=config.is_chief,
**hparams))
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment