Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
02c6f1ac
Commit
02c6f1ac
authored
Aug 28, 2017
by
Toby Boyd
Browse files
Changed to params
parent
abeb0356
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
20 additions
and
21 deletions
+20
-21
tutorials/image/cifar10_estimator/cifar10_main.py
tutorials/image/cifar10_estimator/cifar10_main.py
+20
-21
No files found.
tutorials/image/cifar10_estimator/cifar10_main.py
View file @
02c6f1ac
...
...
@@ -44,8 +44,9 @@ import tensorflow as tf
tf
.
logging
.
set_verbosity
(
tf
.
logging
.
INFO
)
def
get_model_fn
(
num_gpus
,
variable_strategy
,
data_format
,
num_workers
):
def
get_model_fn
(
num_gpus
,
variable_strategy
,
num_workers
):
"""Returns a function that will build the resnet model."""
def
_resnet_model_fn
(
features
,
labels
,
mode
,
params
):
"""Resnet model body.
...
...
@@ -73,6 +74,16 @@ def get_model_fn(num_gpus, variable_strategy, data_format, num_workers):
tower_gradvars
=
[]
tower_preds
=
[]
# channels first (NCHW) is normally optimal on GPU and channels last (NHWC)
# on CPU. The exception is Intel MKL on CPU which is optimal with
# channels_last.
data_format
=
params
.
data_format
if
not
data_format
:
if
num_gpus
==
0
:
data_format
=
'channels_last'
else
:
data_format
=
'channels_first'
if
num_gpus
==
0
:
num_devices
=
1
device_type
=
'cpu'
...
...
@@ -276,7 +287,6 @@ def input_fn(data_dir,
def
get_experiment_fn
(
data_dir
,
num_gpus
,
variable_strategy
,
data_format
,
use_distortion_for_training
=
True
):
"""Returns an Experiment function.
...
...
@@ -291,7 +301,6 @@ def get_experiment_fn(data_dir,
num_gpus: int. Number of GPUs on each worker.
variable_strategy: String. CPU to use CPU as the parameter server
and GPU to use the GPUs as the parameter server.
data_format: String. channels_first or channels_last.
use_distortion_for_training: bool. See cifar10.Cifar10DataSet.
Returns:
A function (tf.estimator.RunConfig, tf.contrib.training.HParams) ->
...
...
@@ -338,11 +347,10 @@ def get_experiment_fn(data_dir,
hooks
=
[
logging_hook
,
examples_sec_hook
]
classifier
=
tf
.
estimator
.
Estimator
(
model_fn
=
get_model_fn
(
num_gpus
,
variable_strategy
,
data_format
,
model_fn
=
get_model_fn
(
num_gpus
,
variable_strategy
,
run_config
.
num_worker_replicas
or
1
),
config
=
run_config
,
params
=
hparams
)
params
=
hparams
)
# Create experiment.
experiment
=
tf
.
contrib
.
learn
.
Experiment
(
...
...
@@ -354,25 +362,17 @@ def get_experiment_fn(data_dir,
# Adding hooks to be used by the estimator on training modes
experiment
.
extend_train_hooks
(
hooks
)
return
experiment
return
_experiment_fn
def
main
(
job_dir
,
data_dir
,
num_gpus
,
variable_strategy
,
data_format
,
def
main
(
job_dir
,
data_dir
,
num_gpus
,
variable_strategy
,
use_distortion_for_training
,
log_device_placement
,
num_intra_threads
,
**
hparams
):
# The env variable is on deprecation path, default is set to off.
os
.
environ
[
'TF_SYNC_ON_FINISH'
]
=
'0'
os
.
environ
[
'TF_ENABLE_WINOGRAD_NONFUSED'
]
=
'1'
# channels first (NCHW) is normally optimal on GPU and channels last (NHWC)
# on CPU. The exception is Intel MKL on CPU which is optimal with
# channels_last.
if
not
data_format
:
if
num_gpus
==
0
:
data_format
=
'channels_last'
else
:
data_format
=
'channels_first'
# Session configuration.
sess_config
=
tf
.
ConfigProto
(
allow_soft_placement
=
True
,
...
...
@@ -383,7 +383,7 @@ def main(job_dir, data_dir, num_gpus, variable_strategy, data_format,
config
=
cifar10_utils
.
RunConfig
(
session_config
=
sess_config
,
model_dir
=
job_dir
)
tf
.
contrib
.
learn
.
learn_runner
.
run
(
get_experiment_fn
(
data_dir
,
num_gpus
,
variable_strategy
,
data_format
,
get_experiment_fn
(
data_dir
,
num_gpus
,
variable_strategy
,
use_distortion_for_training
),
run_config
=
config
,
hparams
=
tf
.
contrib
.
training
.
HParams
(
**
hparams
))
...
...
@@ -509,8 +509,7 @@ if __name__ == '__main__':
raise
ValueError
(
'Invalid GPU count:
\"
--num-gpus
\"
must be 0 or a positive integer.'
)
if
args
.
num_gpus
==
0
and
args
.
variable_strategy
==
'GPU'
:
raise
ValueError
(
'num-gpus=0, CPU must be used as parameter server. Set'
raise
ValueError
(
'num-gpus=0, CPU must be used as parameter server. Set'
'--variable-strategy=CPU.'
)
if
(
args
.
num_layers
-
2
)
%
6
!=
0
:
raise
ValueError
(
'Invalid --num-layers parameter.'
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment