Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
b36c01b6
Commit
b36c01b6
authored
Jul 06, 2017
by
Marianne Linhares Monteiro
Committed by
GitHub
Jul 06, 2017
Browse files
Adding run_experiment option
parent
d71cbd0c
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
42 additions
and
21 deletions
+42
-21
tutorials/image/cifar10_estimator/cifar10_main.py
tutorials/image/cifar10_estimator/cifar10_main.py
+42
-21
No files found.
tutorials/image/cifar10_estimator/cifar10_main.py
View file @
b36c01b6
...
@@ -72,6 +72,11 @@ tf.flags.DEFINE_float('weight_decay', 1e-4, 'Weight decay for convolutions.')
...
@@ -72,6 +72,11 @@ tf.flags.DEFINE_float('weight_decay', 1e-4, 'Weight decay for convolutions.')
tf
.
flags
.
DEFINE_boolean
(
'use_distortion_for_training'
,
True
,
tf
.
flags
.
DEFINE_boolean
(
'use_distortion_for_training'
,
True
,
'If doing image distortion for training.'
)
'If doing image distortion for training.'
)
tf
.
flags
.
DEFINE_boolean
(
'run_experiment'
,
False
,
"If True will run an experiment,"
"otherwise will run training and evaluatio"
"using the estimator's methods"
)
# Perf flags
# Perf flags
tf
.
flags
.
DEFINE_integer
(
'num_intra_threads'
,
1
,
tf
.
flags
.
DEFINE_integer
(
'num_intra_threads'
,
1
,
"""Number of threads to use for intra-op parallelism.
"""Number of threads to use for intra-op parallelism.
...
@@ -359,6 +364,19 @@ def input_fn(subset, num_shards):
...
@@ -359,6 +364,19 @@ def input_fn(subset, num_shards):
label_shards
=
[
tf
.
parallel_stack
(
x
)
for
x
in
label_shards
]
label_shards
=
[
tf
.
parallel_stack
(
x
)
for
x
in
label_shards
]
return
feature_shards
,
label_shards
return
feature_shards
,
label_shards
# create experiment
def
experiment_fn
(
run_config
,
hparams
):
# create estimator
classifier
=
tf
.
estimator
.
Estimator
(
model_fn
=
_resnet_model_fn
,
config
=
run_config
)
return
tf
.
contrib
.
learn
.
Experiment
(
classifier
,
train_input_fn
=
train_input_fn
,
eval_input_fn
=
test_input_fn
,
train_steps
=
FLAGS
.
train_steps
,
eval_steps
=
num_eval_examples
//
FLAGS
.
eval_batch_size
)
def
main
(
unused_argv
):
def
main
(
unused_argv
):
# The env variable is on deprecation path, default is set to off.
# The env variable is on deprecation path, default is set to off.
...
@@ -381,7 +399,7 @@ def main(unused_argv):
...
@@ -381,7 +399,7 @@ def main(unused_argv):
if
num_eval_examples
%
FLAGS
.
eval_batch_size
!=
0
:
if
num_eval_examples
%
FLAGS
.
eval_batch_size
!=
0
:
raise
ValueError
(
'validation set size must be multiple of eval_batch_size'
)
raise
ValueError
(
'validation set size must be multiple of eval_batch_size'
)
config
=
tf
.
estimator
.
RunConfig
(
)
config
=
tf
.
contrib
.
learn
.
RunConfig
(
model_dir
=
FLAGS
.
model_dir
)
sess_config
=
tf
.
ConfigProto
()
sess_config
=
tf
.
ConfigProto
()
sess_config
.
allow_soft_placement
=
True
sess_config
.
allow_soft_placement
=
True
sess_config
.
log_device_placement
=
FLAGS
.
log_device_placement
sess_config
.
log_device_placement
=
FLAGS
.
log_device_placement
...
@@ -390,26 +408,29 @@ def main(unused_argv):
...
@@ -390,26 +408,29 @@ def main(unused_argv):
sess_config
.
gpu_options
.
force_gpu_compatible
=
FLAGS
.
force_gpu_compatible
sess_config
.
gpu_options
.
force_gpu_compatible
=
FLAGS
.
force_gpu_compatible
config
=
config
.
replace
(
session_config
=
sess_config
)
config
=
config
.
replace
(
session_config
=
sess_config
)
classifier
=
tf
.
estimator
.
Estimator
(
if
FLAGS
.
run_experiment
:
model_fn
=
_resnet_model_fn
,
model_dir
=
FLAGS
.
model_dir
,
config
=
config
)
tf
.
contrib
.
learn
.
learn_runner
.
run
(
experiment_fn
,
run_config
=
run_config
)
else
:
tensors_to_log
=
{
'learning_rate'
:
'learning_rate'
}
classifier
=
tf
.
estimator
.
Estimator
(
logging_hook
=
tf
.
train
.
LoggingTensorHook
(
model_fn
=
_resnet_model_fn
,
config
=
config
)
tensors
=
tensors_to_log
,
every_n_iter
=
100
)
tensors_to_log
=
{
'learning_rate'
:
'learning_rate'
}
print
(
'Starting to train...'
)
logging_hook
=
tf
.
train
.
LoggingTensorHook
(
classifier
.
train
(
tensors
=
tensors_to_log
,
every_n_iter
=
100
)
input_fn
=
functools
.
partial
(
input_fn
,
subset
=
'train'
,
num_shards
=
FLAGS
.
num_gpus
),
print
(
'Starting to train...'
)
steps
=
FLAGS
.
train_steps
,
classifier
.
train
(
hooks
=
[
logging_hook
])
input_fn
=
functools
.
partial
(
input_fn
,
subset
=
'train'
,
num_shards
=
FLAGS
.
num_gpus
),
print
(
'Starting to evaluate...'
)
steps
=
FLAGS
.
train_steps
,
eval_results
=
classifier
.
evaluate
(
hooks
=
[
logging_hook
])
input_fn
=
functools
.
partial
(
input_fn
,
subset
=
'eval'
,
num_shards
=
FLAGS
.
num_gpus
),
print
(
'Starting to evaluate...'
)
steps
=
num_eval_examples
//
FLAGS
.
eval_batch_size
)
eval_results
=
classifier
.
evaluate
(
print
(
eval_results
)
input_fn
=
functools
.
partial
(
input_fn
,
subset
=
'eval'
,
num_shards
=
FLAGS
.
num_gpus
),
steps
=
num_eval_examples
//
FLAGS
.
eval_batch_size
)
print
(
eval_results
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment