Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
931c70a1
"tests/vscode:/vscode.git/clone" did not exist on "3ffa7b46e5d896dc35264b50325460f554556a93"
Commit
931c70a1
authored
Dec 08, 2016
by
Mustafa Ispir
Committed by
Mustafa Ispir
Dec 08, 2016
Browse files
Convert resnet model to use monitored_session
parent
a533325c
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
67 additions
and
50 deletions
+67
-50
resnet/cifar_input.py
resnet/cifar_input.py
+1
-1
resnet/resnet_main.py
resnet/resnet_main.py
+65
-48
resnet/resnet_model.py
resnet/resnet_model.py
+1
-1
No files found.
resnet/cifar_input.py
View file @
931c70a1
...
...
@@ -73,7 +73,7 @@ def build_input(dataset, data_path, batch_size, mode):
# image = tf.image.random_brightness(image, max_delta=63. / 255.)
# image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
# image = tf.image.random_contrast(image, lower=0.2, upper=1.8)
image
=
tf
.
image
.
per_image_
whitening
(
image
)
image
=
tf
.
image
.
per_image_
standardization
(
image
)
example_queue
=
tf
.
RandomShuffleQueue
(
capacity
=
16
*
batch_size
,
...
...
resnet/resnet_main.py
View file @
931c70a1
...
...
@@ -15,8 +15,8 @@
"""ResNet Train/Eval module.
"""
import
sys
import
time
import
sys
import
cifar_input
import
numpy
as
np
...
...
@@ -26,8 +26,10 @@ import tensorflow as tf
FLAGS
=
tf
.
app
.
flags
.
FLAGS
tf
.
app
.
flags
.
DEFINE_string
(
'dataset'
,
'cifar10'
,
'cifar10 or cifar100.'
)
tf
.
app
.
flags
.
DEFINE_string
(
'mode'
,
'train'
,
'train or eval.'
)
tf
.
app
.
flags
.
DEFINE_string
(
'train_data_path'
,
''
,
'Filepattern for training data.'
)
tf
.
app
.
flags
.
DEFINE_string
(
'eval_data_path'
,
''
,
'Filepattern for eval data'
)
tf
.
app
.
flags
.
DEFINE_string
(
'train_data_path'
,
''
,
'Filepattern for training data.'
)
tf
.
app
.
flags
.
DEFINE_string
(
'eval_data_path'
,
''
,
'Filepattern for eval data'
)
tf
.
app
.
flags
.
DEFINE_integer
(
'image_size'
,
32
,
'Image side length.'
)
tf
.
app
.
flags
.
DEFINE_string
(
'train_dir'
,
''
,
'Directory to keep training outputs.'
)
...
...
@@ -50,50 +52,65 @@ def train(hps):
FLAGS
.
dataset
,
FLAGS
.
train_data_path
,
hps
.
batch_size
,
FLAGS
.
mode
)
model
=
resnet_model
.
ResNet
(
hps
,
images
,
labels
,
FLAGS
.
mode
)
model
.
build_graph
()
summary_writer
=
tf
.
train
.
SummaryWriter
(
FLAGS
.
train_dir
)
sv
=
tf
.
train
.
Supervisor
(
logdir
=
FLAGS
.
log_root
,
is_chief
=
True
,
summary_op
=
None
,
save_summaries_secs
=
60
,
save_model_secs
=
300
,
global_step
=
model
.
global_step
)
sess
=
sv
.
prepare_or_wait_for_session
(
config
=
tf
.
ConfigProto
(
allow_soft_placement
=
True
))
step
=
0
lrn_rate
=
0.1
while
not
sv
.
should_stop
():
(
_
,
summaries
,
loss
,
predictions
,
truth
,
train_step
)
=
sess
.
run
(
[
model
.
train_op
,
model
.
summaries
,
model
.
cost
,
model
.
predictions
,
model
.
labels
,
model
.
global_step
],
feed_dict
=
{
model
.
lrn_rate
:
lrn_rate
})
if
train_step
<
40000
:
lrn_rate
=
0.1
elif
train_step
<
60000
:
lrn_rate
=
0.01
elif
train_step
<
80000
:
lrn_rate
=
0.001
else
:
lrn_rate
=
0.0001
truth
=
np
.
argmax
(
truth
,
axis
=
1
)
predictions
=
np
.
argmax
(
predictions
,
axis
=
1
)
precision
=
np
.
mean
(
truth
==
predictions
)
step
+=
1
if
step
%
100
==
0
:
precision_summ
=
tf
.
Summary
()
precision_summ
.
value
.
add
(
tag
=
'Precision'
,
simple_value
=
precision
)
summary_writer
.
add_summary
(
precision_summ
,
train_step
)
summary_writer
.
add_summary
(
summaries
,
train_step
)
tf
.
logging
.
info
(
'loss: %.3f, precision: %.3f
\n
'
%
(
loss
,
precision
))
summary_writer
.
flush
()
sv
.
Stop
()
param_stats
=
tf
.
contrib
.
tfprof
.
model_analyzer
.
print_model_analysis
(
tf
.
get_default_graph
(),
tfprof_options
=
tf
.
contrib
.
tfprof
.
model_analyzer
.
TRAINABLE_VARS_PARAMS_STAT_OPTIONS
)
sys
.
stdout
.
write
(
'total_params: %d
\n
'
%
param_stats
.
total_parameters
)
tf
.
contrib
.
tfprof
.
model_analyzer
.
print_model_analysis
(
tf
.
get_default_graph
(),
tfprof_options
=
tf
.
contrib
.
tfprof
.
model_analyzer
.
FLOAT_OPS_OPTIONS
)
truth
=
tf
.
argmax
(
model
.
labels
,
axis
=
1
)
predictions
=
tf
.
argmax
(
model
.
predictions
,
axis
=
1
)
precision
=
tf
.
reduce_mean
(
tf
.
to_float
(
tf
.
equal
(
predictions
,
truth
)))
summary_hook
=
tf
.
train
.
SummarySaverHook
(
save_steps
=
100
,
output_dir
=
FLAGS
.
train_dir
,
summary_op
=
[
model
.
summaries
,
tf
.
summary
.
scalar
(
'Precision'
,
precision
)])
logging_hook
=
tf
.
train
.
LoggingTensorHook
(
tensors
=
{
'step'
:
model
.
global_step
,
'loss'
:
model
.
cost
,
'precision'
:
precision
},
every_n_iter
=
100
)
class
_LearningRateSetterHook
(
tf
.
train
.
SessionRunHook
):
"""Sets learning_rate based on global step."""
def
begin
(
self
):
self
.
_lrn_rate
=
0.1
def
before_run
(
self
,
run_context
):
return
tf
.
train
.
SessionRunArgs
(
model
.
global_step
,
# Asks for global step value.
feed_dict
=
{
model
.
lrn_rate
:
self
.
_lrn_rate
})
# Sets learning rate
def
after_run
(
self
,
run_context
,
run_values
):
train_step
=
run_values
.
results
if
train_step
<
40000
:
self
.
_lrn_rate
=
0.1
elif
train_step
<
60000
:
self
.
_lrn_rate
=
0.01
elif
train_step
<
80000
:
self
.
_lrn_rate
=
0.001
else
:
self
.
_lrn_rate
=
0.0001
with
tf
.
train
.
MonitoredTrainingSession
(
checkpoint_dir
=
FLAGS
.
log_root
,
hooks
=
[
logging_hook
,
_LearningRateSetterHook
()],
chief_only_hooks
=
[
summary_hook
],
# Since we provide a SummarySaverHook, we need to disable default
# SummarySaverHook. To do that we set save_summaries_steps to 0.
save_summaries_steps
=
0
,
config
=
tf
.
ConfigProto
(
allow_soft_placement
=
True
))
as
mon_sess
:
while
not
mon_sess
.
should_stop
():
mon_sess
.
run
(
model
.
train_op
)
def
evaluate
(
hps
):
...
...
@@ -103,7 +120,7 @@ def evaluate(hps):
model
=
resnet_model
.
ResNet
(
hps
,
images
,
labels
,
FLAGS
.
mode
)
model
.
build_graph
()
saver
=
tf
.
train
.
Saver
()
summary_writer
=
tf
.
train
.
S
ummaryWriter
(
FLAGS
.
eval_dir
)
summary_writer
=
tf
.
s
ummary
.
File
Writer
(
FLAGS
.
eval_dir
)
sess
=
tf
.
Session
(
config
=
tf
.
ConfigProto
(
allow_soft_placement
=
True
))
tf
.
train
.
start_queue_runners
(
sess
)
...
...
resnet/resnet_model.py
View file @
931c70a1
...
...
@@ -55,7 +55,7 @@ class ResNet(object):
def
build_graph
(
self
):
"""Build a whole graph for the model."""
self
.
global_step
=
tf
.
Variable
(
0
,
name
=
'global_step'
,
trainable
=
False
)
self
.
global_step
=
tf
.
contrib
.
framework
.
get_or_create_global_step
(
)
self
.
_build_model
()
if
self
.
mode
==
'train'
:
self
.
_build_train_op
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment