Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
cd5e9b7c
Commit
cd5e9b7c
authored
Sep 23, 2016
by
Christopher Shallue
Browse files
Fix a bug in the im2txt code where the Saver is created before the
optimizer.
parent
71f239fd
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
16 additions
and
19 deletions
+16
-19
im2txt/im2txt/configuration.py
im2txt/im2txt/configuration.py
+3
-4
im2txt/im2txt/evaluate.py
im2txt/im2txt/evaluate.py
+7
-3
im2txt/im2txt/inference_utils/inference_wrapper_base.py
im2txt/im2txt/inference_utils/inference_wrapper_base.py
+2
-4
im2txt/im2txt/show_and_tell_model.py
im2txt/im2txt/show_and_tell_model.py
+0
-7
im2txt/im2txt/train.py
im2txt/im2txt/train.py
+4
-1
No files found.
im2txt/im2txt/configuration.py
View file @
cd5e9b7c
...
@@ -77,10 +77,6 @@ class ModelConfig(object):
...
@@ -77,10 +77,6 @@ class ModelConfig(object):
# If < 1.0, the dropout keep probability applied to LSTM variables.
# If < 1.0, the dropout keep probability applied to LSTM variables.
self
.
lstm_dropout_keep_prob
=
0.7
self
.
lstm_dropout_keep_prob
=
0.7
# How many model checkpoints to keep.
self
.
max_checkpoints_to_keep
=
5
self
.
keep_checkpoint_every_n_hours
=
10000
class
TrainingConfig
(
object
):
class
TrainingConfig
(
object
):
"""Wrapper class for training hyperparameters."""
"""Wrapper class for training hyperparameters."""
...
@@ -103,3 +99,6 @@ class TrainingConfig(object):
...
@@ -103,3 +99,6 @@ class TrainingConfig(object):
# If not None, clip gradients to this value.
# If not None, clip gradients to this value.
self
.
clip_gradients
=
5.0
self
.
clip_gradients
=
5.0
# How many model checkpoints to keep.
self
.
max_checkpoints_to_keep
=
5
im2txt/im2txt/evaluate.py
View file @
cd5e9b7c
...
@@ -104,11 +104,12 @@ def evaluate_model(sess, model, global_step, summary_writer, summary_op):
...
@@ -104,11 +104,12 @@ def evaluate_model(sess, model, global_step, summary_writer, summary_op):
global_step
)
global_step
)
def
run_once
(
model
,
summary_writer
,
summary_op
):
def
run_once
(
model
,
saver
,
summary_writer
,
summary_op
):
"""Evaluates the latest model checkpoint.
"""Evaluates the latest model checkpoint.
Args:
Args:
model: Instance of ShowAndTellModel; the model to evaluate.
model: Instance of ShowAndTellModel; the model to evaluate.
saver: Instance of tf.train.Saver for restoring model Variables.
summary_writer: Instance of SummaryWriter.
summary_writer: Instance of SummaryWriter.
summary_op: Op for generating model summaries.
summary_op: Op for generating model summaries.
"""
"""
...
@@ -121,7 +122,7 @@ def run_once(model, summary_writer, summary_op):
...
@@ -121,7 +122,7 @@ def run_once(model, summary_writer, summary_op):
with
tf
.
Session
()
as
sess
:
with
tf
.
Session
()
as
sess
:
# Load model from checkpoint.
# Load model from checkpoint.
tf
.
logging
.
info
(
"Loading model from checkpoint: %s"
,
model_path
)
tf
.
logging
.
info
(
"Loading model from checkpoint: %s"
,
model_path
)
model
.
saver
.
restore
(
sess
,
model_path
)
saver
.
restore
(
sess
,
model_path
)
global_step
=
tf
.
train
.
global_step
(
sess
,
model
.
global_step
.
name
)
global_step
=
tf
.
train
.
global_step
(
sess
,
model
.
global_step
.
name
)
tf
.
logging
.
info
(
"Successfully loaded %s at global step = %d."
,
tf
.
logging
.
info
(
"Successfully loaded %s at global step = %d."
,
os
.
path
.
basename
(
model_path
),
global_step
)
os
.
path
.
basename
(
model_path
),
global_step
)
...
@@ -166,6 +167,9 @@ def run():
...
@@ -166,6 +167,9 @@ def run():
model
=
show_and_tell_model
.
ShowAndTellModel
(
model_config
,
mode
=
"eval"
)
model
=
show_and_tell_model
.
ShowAndTellModel
(
model_config
,
mode
=
"eval"
)
model
.
build
()
model
.
build
()
# Create the Saver to restore model Variables.
saver
=
tf
.
train
.
Saver
()
# Create the summary operation and the summary writer.
# Create the summary operation and the summary writer.
summary_op
=
tf
.
merge_all_summaries
()
summary_op
=
tf
.
merge_all_summaries
()
summary_writer
=
tf
.
train
.
SummaryWriter
(
eval_dir
)
summary_writer
=
tf
.
train
.
SummaryWriter
(
eval_dir
)
...
@@ -177,7 +181,7 @@ def run():
...
@@ -177,7 +181,7 @@ def run():
start
=
time
.
time
()
start
=
time
.
time
()
tf
.
logging
.
info
(
"Starting evaluation at "
+
time
.
strftime
(
tf
.
logging
.
info
(
"Starting evaluation at "
+
time
.
strftime
(
"%Y-%m-%d-%H:%M:%S"
,
time
.
localtime
()))
"%Y-%m-%d-%H:%M:%S"
,
time
.
localtime
()))
run_once
(
model
,
summary_writer
,
summary_op
)
run_once
(
model
,
saver
,
summary_writer
,
summary_op
)
time_to_next_eval
=
start
+
FLAGS
.
eval_interval_secs
-
time
.
time
()
time_to_next_eval
=
start
+
FLAGS
.
eval_interval_secs
-
time
.
time
()
if
time_to_next_eval
>
0
:
if
time_to_next_eval
>
0
:
time
.
sleep
(
time_to_next_eval
)
time
.
sleep
(
time_to_next_eval
)
...
...
im2txt/im2txt/inference_utils/inference_wrapper_base.py
View file @
cd5e9b7c
...
@@ -112,10 +112,8 @@ class InferenceWrapperBase(object):
...
@@ -112,10 +112,8 @@ class InferenceWrapperBase(object):
from the checkpoint file.
from the checkpoint file.
"""
"""
tf
.
logging
.
info
(
"Building model."
)
tf
.
logging
.
info
(
"Building model."
)
model
=
self
.
build_model
(
model_config
)
self
.
build_model
(
model_config
)
saver
=
model
.
saver
saver
=
tf
.
train
.
Saver
()
if
not
saver
:
saver
=
tf
.
Saver
()
return
self
.
_create_restore_fn
(
checkpoint_path
,
saver
)
return
self
.
_create_restore_fn
(
checkpoint_path
,
saver
)
...
...
im2txt/im2txt/show_and_tell_model.py
View file @
cd5e9b7c
...
@@ -347,12 +347,6 @@ class ShowAndTellModel(object):
...
@@ -347,12 +347,6 @@ class ShowAndTellModel(object):
self
.
global_step
=
global_step
self
.
global_step
=
global_step
def
setup_saver
(
self
):
"""Sets up the Saver for loading and saving model checkpoints."""
self
.
saver
=
tf
.
train
.
Saver
(
max_to_keep
=
self
.
config
.
max_checkpoints_to_keep
,
keep_checkpoint_every_n_hours
=
self
.
config
.
keep_checkpoint_every_n_hours
)
def
build
(
self
):
def
build
(
self
):
"""Creates all ops for training and evaluation."""
"""Creates all ops for training and evaluation."""
self
.
build_inputs
()
self
.
build_inputs
()
...
@@ -361,4 +355,3 @@ class ShowAndTellModel(object):
...
@@ -361,4 +355,3 @@ class ShowAndTellModel(object):
self
.
build_model
()
self
.
build_model
()
self
.
setup_inception_initializer
()
self
.
setup_inception_initializer
()
self
.
setup_global_step
()
self
.
setup_global_step
()
self
.
setup_saver
()
im2txt/im2txt/train.py
View file @
cd5e9b7c
...
@@ -95,6 +95,9 @@ def main(unused_argv):
...
@@ -95,6 +95,9 @@ def main(unused_argv):
clip_gradients
=
training_config
.
clip_gradients
,
clip_gradients
=
training_config
.
clip_gradients
,
learning_rate_decay_fn
=
learning_rate_decay_fn
)
learning_rate_decay_fn
=
learning_rate_decay_fn
)
# Set up the Saver for saving and restoring model checkpoints.
saver
=
tf
.
train
.
Saver
(
max_to_keep
=
training_config
.
max_checkpoints_to_keep
)
# Run training.
# Run training.
tf
.
contrib
.
slim
.
learning
.
train
(
tf
.
contrib
.
slim
.
learning
.
train
(
train_op
,
train_op
,
...
@@ -104,7 +107,7 @@ def main(unused_argv):
...
@@ -104,7 +107,7 @@ def main(unused_argv):
global_step
=
model
.
global_step
,
global_step
=
model
.
global_step
,
number_of_steps
=
FLAGS
.
number_of_steps
,
number_of_steps
=
FLAGS
.
number_of_steps
,
init_fn
=
model
.
init_fn
,
init_fn
=
model
.
init_fn
,
saver
=
model
.
saver
)
saver
=
saver
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment