Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
534e86b3
Commit
534e86b3
authored
Sep 05, 2017
by
Eli Bixby
Browse files
Fix incorrect SyncReplicasOptimizer usage
parent
6024579b
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
23 additions
and
22 deletions
+23
-22
tutorials/image/cifar10_estimator/cifar10_main.py
tutorials/image/cifar10_estimator/cifar10_main.py
+23
-22
No files found.
tutorials/image/cifar10_estimator/cifar10_main.py
View file @
534e86b3
...
@@ -153,18 +153,26 @@ def get_model_fn(num_gpus, variable_strategy, num_workers):
...
@@ -153,18 +153,26 @@ def get_model_fn(num_gpus, variable_strategy, num_workers):
learning_rate
=
tf
.
train
.
piecewise_constant
(
tf
.
train
.
get_global_step
(),
learning_rate
=
tf
.
train
.
piecewise_constant
(
tf
.
train
.
get_global_step
(),
boundaries
,
staged_lr
)
boundaries
,
staged_lr
)
# Create a nicely-named tensor for logging
learning_rate
=
tf
.
identity
(
learning_rate
,
name
=
'learning_rate'
)
loss
=
tf
.
reduce_mean
(
tower_losses
,
name
=
'loss'
)
optimizer
=
tf
.
train
.
MomentumOptimizer
(
optimizer
=
tf
.
train
.
MomentumOptimizer
(
learning_rate
=
learning_rate
,
momentum
=
momentum
)
learning_rate
=
learning_rate
,
momentum
=
momentum
)
chief_hooks
=
[]
examples_sec_hook
=
cifar10_utils
.
ExamplesPerSecondHook
(
params
.
train_batch_size
,
every_n_steps
=
10
)
tensors_to_log
=
{
'learning_rate'
:
learning_rate
,
'loss'
:
loss
}
logging_hook
=
tf
.
train
.
LoggingTensorHook
(
tensors
=
tensors_to_log
,
every_n_iter
=
100
)
train_hooks
=
[
logging_hook
,
examples_sec_hook
]
if
params
.
sync
:
if
params
.
sync
:
optimizer
=
tf
.
train
.
SyncReplicasOptimizer
(
optimizer
=
tf
.
train
.
SyncReplicasOptimizer
(
optimizer
,
replicas_to_aggregate
=
num_workers
)
optimizer
,
replicas_to_aggregate
=
num_workers
)
sync_replicas_hook
=
optimizer
.
make_session_run_hook
(
True
)
sync_replicas_hook
=
optimizer
.
make_session_run_hook
(
params
.
is_chief
)
chief
_hooks
.
append
(
sync_replicas_hook
)
train
_hooks
.
append
(
sync_replicas_hook
)
# Create single grouped train op
# Create single grouped train op
train_op
=
[
train_op
=
[
...
@@ -185,14 +193,13 @@ def get_model_fn(num_gpus, variable_strategy, num_workers):
...
@@ -185,14 +193,13 @@ def get_model_fn(num_gpus, variable_strategy, num_workers):
'accuracy'
:
'accuracy'
:
tf
.
metrics
.
accuracy
(
stacked_labels
,
predictions
[
'classes'
])
tf
.
metrics
.
accuracy
(
stacked_labels
,
predictions
[
'classes'
])
}
}
loss
=
tf
.
reduce_mean
(
tower_losses
,
name
=
'loss'
)
return
tf
.
estimator
.
EstimatorSpec
(
return
tf
.
estimator
.
EstimatorSpec
(
mode
=
mode
,
mode
=
mode
,
predictions
=
predictions
,
predictions
=
predictions
,
loss
=
loss
,
loss
=
loss
,
train_op
=
train_op
,
train_op
=
train_op
,
training_
chief_
hooks
=
chief
_hooks
,
training_hooks
=
train
_hooks
,
eval_metric_ops
=
metrics
)
eval_metric_ops
=
metrics
)
return
_resnet_model_fn
return
_resnet_model_fn
...
@@ -336,32 +343,24 @@ def get_experiment_fn(data_dir,
...
@@ -336,32 +343,24 @@ def get_experiment_fn(data_dir,
train_steps
=
hparams
.
train_steps
train_steps
=
hparams
.
train_steps
eval_steps
=
num_eval_examples
//
hparams
.
eval_batch_size
eval_steps
=
num_eval_examples
//
hparams
.
eval_batch_size
examples_sec_hook
=
cifar10_utils
.
ExamplesPerSecondHook
(
hparams
.
train_batch_size
,
every_n_steps
=
10
)
tensors_to_log
=
{
'learning_rate'
:
'learning_rate'
,
'loss'
:
'loss'
}
if
run_config
.
num_worker_replicas
:
num_workers
=
run_config
.
num_worker_replicas
+
1
logging_hook
=
tf
.
train
.
LoggingTensorHook
(
else
:
tensors
=
tensors_to_log
,
every_n_iter
=
100
)
num_workers
=
1
hooks
=
[
logging_hook
,
examples_sec_hook
]
classifier
=
tf
.
estimator
.
Estimator
(
classifier
=
tf
.
estimator
.
Estimator
(
model_fn
=
get_model_fn
(
num_gpus
,
variable_strategy
,
model_fn
=
get_model_fn
(
num_gpus
,
variable_strategy
,
num_workers
),
run_config
.
num_worker_replicas
or
1
),
config
=
run_config
,
config
=
run_config
,
params
=
hparams
)
params
=
hparams
)
# Create experiment.
# Create experiment.
experiment
=
tf
.
contrib
.
learn
.
Experiment
(
return
tf
.
contrib
.
learn
.
Experiment
(
classifier
,
classifier
,
train_input_fn
=
train_input_fn
,
train_input_fn
=
train_input_fn
,
eval_input_fn
=
eval_input_fn
,
eval_input_fn
=
eval_input_fn
,
train_steps
=
train_steps
,
train_steps
=
train_steps
,
eval_steps
=
eval_steps
)
eval_steps
=
eval_steps
)
# Adding hooks to be used by the estimator on training modes
experiment
.
extend_train_hooks
(
hooks
)
return
experiment
return
_experiment_fn
return
_experiment_fn
...
@@ -386,7 +385,9 @@ def main(job_dir, data_dir, num_gpus, variable_strategy,
...
@@ -386,7 +385,9 @@ def main(job_dir, data_dir, num_gpus, variable_strategy,
get_experiment_fn
(
data_dir
,
num_gpus
,
variable_strategy
,
get_experiment_fn
(
data_dir
,
num_gpus
,
variable_strategy
,
use_distortion_for_training
),
use_distortion_for_training
),
run_config
=
config
,
run_config
=
config
,
hparams
=
tf
.
contrib
.
training
.
HParams
(
**
hparams
))
hparams
=
tf
.
contrib
.
training
.
HParams
(
is_chief
=
config
.
is_chief
,
**
hparams
))
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment