Commit dd9a81c0 authored by Igor's avatar Igor Committed by GitHub
Browse files

Merge pull request #2343 from elibixby/fixsyncreplicausage

Fix incorrect SyncReplicasOptimizer usage
parents 78bb025c 28d37e7a
......@@ -153,18 +153,27 @@ def get_model_fn(num_gpus, variable_strategy, num_workers):
learning_rate = tf.train.piecewise_constant(tf.train.get_global_step(),
boundaries, staged_lr)
# Create a nicely-named tensor for logging
learning_rate = tf.identity(learning_rate, name='learning_rate')
loss = tf.reduce_mean(tower_losses, name='loss')
examples_sec_hook = cifar10_utils.ExamplesPerSecondHook(
params.train_batch_size, every_n_steps=10)
tensors_to_log = {'learning_rate': learning_rate, 'loss': loss}
logging_hook = tf.train.LoggingTensorHook(
tensors=tensors_to_log, every_n_iter=100)
train_hooks = [logging_hook, examples_sec_hook]
optimizer = tf.train.MomentumOptimizer(
learning_rate=learning_rate, momentum=momentum)
chief_hooks = []
if params.sync:
optimizer = tf.train.SyncReplicasOptimizer(
optimizer, replicas_to_aggregate=num_workers)
sync_replicas_hook = optimizer.make_session_run_hook(True)
chief_hooks.append(sync_replicas_hook)
sync_replicas_hook = optimizer.make_session_run_hook(params.is_chief)
train_hooks.append(sync_replicas_hook)
# Create single grouped train op
train_op = [
......@@ -185,14 +194,13 @@ def get_model_fn(num_gpus, variable_strategy, num_workers):
'accuracy':
tf.metrics.accuracy(stacked_labels, predictions['classes'])
}
loss = tf.reduce_mean(tower_losses, name='loss')
return tf.estimator.EstimatorSpec(
mode=mode,
predictions=predictions,
loss=loss,
train_op=train_op,
training_chief_hooks=chief_hooks,
training_hooks=train_hooks,
eval_metric_ops=metrics)
return _resnet_model_fn
......@@ -336,15 +344,6 @@ def get_experiment_fn(data_dir,
train_steps = hparams.train_steps
eval_steps = num_eval_examples // hparams.eval_batch_size
examples_sec_hook = cifar10_utils.ExamplesPerSecondHook(
hparams.train_batch_size, every_n_steps=10)
tensors_to_log = {'learning_rate': 'learning_rate', 'loss': 'loss'}
logging_hook = tf.train.LoggingTensorHook(
tensors=tensors_to_log, every_n_iter=100)
hooks = [logging_hook, examples_sec_hook]
classifier = tf.estimator.Estimator(
model_fn=get_model_fn(num_gpus, variable_strategy,
......@@ -353,15 +352,12 @@ def get_experiment_fn(data_dir,
params=hparams)
# Create experiment.
experiment = tf.contrib.learn.Experiment(
return tf.contrib.learn.Experiment(
classifier,
train_input_fn=train_input_fn,
eval_input_fn=eval_input_fn,
train_steps=train_steps,
eval_steps=eval_steps)
# Adding hooks to be used by the estimator on training modes
experiment.extend_train_hooks(hooks)
return experiment
return _experiment_fn
......@@ -386,7 +382,9 @@ def main(job_dir, data_dir, num_gpus, variable_strategy,
get_experiment_fn(data_dir, num_gpus, variable_strategy,
use_distortion_for_training),
run_config=config,
hparams=tf.contrib.training.HParams(**hparams))
hparams=tf.contrib.training.HParams(
is_chief=config.is_chief,
**hparams))
if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment