Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
2f69dc64
Commit
2f69dc64
authored
Jul 28, 2017
by
Marianne Linhares Monteiro
Committed by
GitHub
Jul 28, 2017
Browse files
Small style fixes
parent
68218034
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
26 additions
and
14 deletions
+26
-14
tutorials/image/cifar10_estimator/cifar10_main.py
tutorials/image/cifar10_estimator/cifar10_main.py
+26
-14
No files found.
tutorials/image/cifar10_estimator/cifar10_main.py
View file @
2f69dc64
...
...
@@ -33,6 +33,8 @@ import functools
import
operator
import
os
import
cifar10
import
cifar10_model
import
numpy
as
np
from
six.moves
import
xrange
# pylint: disable=redefined-builtin
import
tensorflow
as
tf
...
...
@@ -41,8 +43,6 @@ from tensorflow.python.training import basic_session_run_hooks
from
tensorflow.python.training
import
session_run_hook
from
tensorflow.python.training
import
training_util
import
cifar10
import
cifar10_model
tf
.
logging
.
set_verbosity
(
tf
.
logging
.
INFO
)
...
...
@@ -80,9 +80,10 @@ tf.flags.DEFINE_boolean('run_experiment', False,
'If True will run an experiment,'
'otherwise will run training and evaluation'
'using the estimator interface.'
'Experiments perform training on several workers in parallel'
', in other words experiments know how to invoke train and'
' eval in a sensible fashion for distributed training.'
)
'Experiments perform training on several workers in'
'parallel, in other words experiments know how to'
' invoke train and eval in a sensible fashion for'
' distributed training.'
)
tf
.
flags
.
DEFINE_boolean
(
'sync'
,
False
,
'If true when running in a distributed environment'
...
...
@@ -117,8 +118,8 @@ tf.flags.DEFINE_boolean('log_device_placement', False,
class
ExamplesPerSecondHook
(
session_run_hook
.
SessionRunHook
):
"""Hook to print out examples per second
"""Hook to print out examples per second
.
Total time is tracked and then divided by the total number of steps
to get the average step time and then batch_size is used to determine
the running average of examples per second. The examples per second for the
...
...
@@ -131,15 +132,16 @@ class ExamplesPerSecondHook(session_run_hook.SessionRunHook):
every_n_steps
=
100
,
every_n_secs
=
None
,):
"""Initializer for ExamplesPerSecondHook.
Args:
Args:
batch_size: Total batch size used to calculate examples/second from
global time.
every_n_steps: Log stats every n steps.
every_n_secs: Log stats every n seconds.
"""
"""
if
(
every_n_steps
is
None
)
==
(
every_n_secs
is
None
):
raise
ValueError
(
'exactly one of every_n_steps
and every_n_secs should be provided.'
)
raise
ValueError
(
'exactly one of every_n_steps'
'
and every_n_secs should be provided.'
)
self
.
_timer
=
basic_session_run_hooks
.
SecondOrStepTimer
(
every_steps
=
every_n_steps
,
every_secs
=
every_n_secs
)
...
...
@@ -188,6 +190,7 @@ class GpuParamServerDeviceSetter(object):
def
__init__
(
self
,
worker_device
,
ps_devices
):
"""Initializer for GpuParamServerDeviceSetter.
Args:
worker_device: the device to use for computation Ops.
ps_devices: a list of devices to use for Variable Ops. Each variable is
...
...
@@ -202,7 +205,7 @@ class GpuParamServerDeviceSetter(object):
return
op
.
device
if
op
.
type
not
in
[
'Variable'
,
'VariableV2'
,
'VarHandleOp'
]:
return
self
.
worker_device
# Gets the least loaded ps_device
device_index
,
_
=
min
(
enumerate
(
self
.
ps_sizes
),
key
=
operator
.
itemgetter
(
1
))
device_name
=
self
.
ps_devices
[
device_index
]
...
...
@@ -211,6 +214,7 @@ class GpuParamServerDeviceSetter(object):
return
device_name
def
_create_device_setter
(
is_cpu_ps
,
worker
,
num_gpus
):
"""Create device setter object."""
if
is_cpu_ps
:
...
...
@@ -400,7 +404,8 @@ def input_fn(subset, num_shards):
elif
subset
==
'validate'
or
subset
==
'eval'
:
batch_size
=
FLAGS
.
eval_batch_size
else
:
raise
ValueError
(
'Subset must be one of
\'
train
\'
,
\'
validate
\'
and
\'
eval
\'
'
)
raise
ValueError
(
'Subset must be one of
\'
train
\'
'
',
\'
validate
\'
and
\'
eval
\'
'
)
with
tf
.
device
(
'/cpu:0'
):
use_distortion
=
subset
==
'train'
and
FLAGS
.
use_distortion_for_training
dataset
=
cifar10
.
Cifar10DataSet
(
FLAGS
.
data_dir
,
subset
,
use_distortion
)
...
...
@@ -429,7 +434,14 @@ def input_fn(subset, num_shards):
# create experiment
def
get_experiment_fn
(
train_input_fn
,
eval_input_fn
,
train_steps
,
eval_steps
,
train_hooks
):
"""Returns an Experiment function.
Experiments perform training on several workers in parallel,
in other words experiments know how to invoke train and eval in a sensible
fashion for distributed training.
"""
def
_experiment_fn
(
run_config
,
hparams
):
"""Returns an Experiment."""
del
hparams
# unused arg
# create estimator
classifier
=
tf
.
estimator
.
Estimator
(
model_fn
=
_resnet_model_fn
,
...
...
@@ -491,7 +503,7 @@ def main(unused_argv):
logging_hook
=
tf
.
train
.
LoggingTensorHook
(
tensors
=
tensors_to_log
,
every_n_iter
=
100
)
examples_sec_hook
=
ExamplesPerSecondHook
(
FLAGS
.
train_batch_size
,
every_n_steps
=
10
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment