"tests/vscode:/vscode.git/clone" did not exist on "0b8e29289dd97805f778922c98a13cd0700d3ab3"
Commit 534e86b3 authored by Eli Bixby's avatar Eli Bixby
Browse files

Fix incorrect SyncReplicasOptimizer usage

parent 6024579b
......@@ -153,18 +153,26 @@ def get_model_fn(num_gpus, variable_strategy, num_workers):
learning_rate = tf.train.piecewise_constant(tf.train.get_global_step(),
boundaries, staged_lr)
# Create a nicely-named tensor for logging
learning_rate = tf.identity(learning_rate, name='learning_rate')
loss = tf.reduce_mean(tower_losses, name='loss')
optimizer = tf.train.MomentumOptimizer(
learning_rate=learning_rate, momentum=momentum)
chief_hooks = []
examples_sec_hook = cifar10_utils.ExamplesPerSecondHook(
params.train_batch_size, every_n_steps=10)
tensors_to_log = {'learning_rate': learning_rate, 'loss': loss}
logging_hook = tf.train.LoggingTensorHook(
tensors=tensors_to_log, every_n_iter=100)
train_hooks = [logging_hook, examples_sec_hook]
if params.sync:
optimizer = tf.train.SyncReplicasOptimizer(
optimizer, replicas_to_aggregate=num_workers)
sync_replicas_hook = optimizer.make_session_run_hook(True)
chief_hooks.append(sync_replicas_hook)
sync_replicas_hook = optimizer.make_session_run_hook(params.is_chief)
train_hooks.append(sync_replicas_hook)
# Create single grouped train op
train_op = [
......@@ -185,14 +193,13 @@ def get_model_fn(num_gpus, variable_strategy, num_workers):
'accuracy':
tf.metrics.accuracy(stacked_labels, predictions['classes'])
}
loss = tf.reduce_mean(tower_losses, name='loss')
return tf.estimator.EstimatorSpec(
mode=mode,
predictions=predictions,
loss=loss,
train_op=train_op,
training_chief_hooks=chief_hooks,
training_hooks=train_hooks,
eval_metric_ops=metrics)
return _resnet_model_fn
......@@ -336,32 +343,24 @@ def get_experiment_fn(data_dir,
train_steps = hparams.train_steps
eval_steps = num_eval_examples // hparams.eval_batch_size
examples_sec_hook = cifar10_utils.ExamplesPerSecondHook(
hparams.train_batch_size, every_n_steps=10)
tensors_to_log = {'learning_rate': 'learning_rate', 'loss': 'loss'}
logging_hook = tf.train.LoggingTensorHook(
tensors=tensors_to_log, every_n_iter=100)
hooks = [logging_hook, examples_sec_hook]
if run_config.num_worker_replicas:
num_workers = run_config.num_worker_replicas + 1
else:
num_workers = 1
classifier = tf.estimator.Estimator(
model_fn=get_model_fn(num_gpus, variable_strategy,
run_config.num_worker_replicas or 1),
model_fn=get_model_fn(num_gpus, variable_strategy, num_workers),
config=run_config,
params=hparams)
# Create experiment.
experiment = tf.contrib.learn.Experiment(
return tf.contrib.learn.Experiment(
classifier,
train_input_fn=train_input_fn,
eval_input_fn=eval_input_fn,
train_steps=train_steps,
eval_steps=eval_steps)
# Adding hooks to be used by the estimator on training modes
experiment.extend_train_hooks(hooks)
return experiment
return _experiment_fn
......@@ -386,7 +385,9 @@ def main(job_dir, data_dir, num_gpus, variable_strategy,
get_experiment_fn(data_dir, num_gpus, variable_strategy,
use_distortion_for_training),
run_config=config,
hparams=tf.contrib.training.HParams(**hparams))
hparams=tf.contrib.training.HParams(
is_chief=config.is_chief,
**hparams))
if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment