ModelZoo / ResNet50_tensorflow · Commits

Commit 0864b2a4 (unverified)
Authored Aug 16, 2018 by Niru Maheswaranathan; committed via GitHub on Aug 16, 2018

Merge pull request #5084 from yesmung/master

Modify flag name for the checkpoint path

Parents: 468d8bb6, 2949cfd8
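For context: run_eval.py defines and reads its flags through the TF 1.x flags module, so renaming the flag touches the DEFINE_string call, every FLAGS reference, and the name callers pass on the command line. Below is a minimal, hypothetical sketch of that pattern under those assumptions; it is not the repository's script, and the printed paths are placeholders.

# Hypothetical standalone sketch (not run_eval.py itself), assuming the
# TensorFlow 1.x flags API (a thin wrapper around absl.flags).
import tensorflow as tf

flags = tf.flags

# After this commit the flag is named "checkpoint_dir"; callers therefore pass
# --checkpoint_dir=... instead of the old --checkpoint=...
flags.DEFINE_string("checkpoint_dir", None,
                    "Dir to load pretrained update rule from")
flags.DEFINE_string("train_log_dir", None, "Training log directory")

FLAGS = flags.FLAGS


def main(argv):
  # An unset string flag stays None, so the restore branch is skipped unless
  # --checkpoint_dir is given.
  if FLAGS.checkpoint_dir:
    print("would restore from", FLAGS.checkpoint_dir)
  print("logging to", FLAGS.train_log_dir)


if __name__ == "__main__":
  tf.app.run(main)

Correspondingly, invocations of the real script change from --checkpoint=/path/to/ckpt to --checkpoint_dir=/path/to/ckpt.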
Showing 1 changed file with 7 additions and 7 deletions.
research/learning_unsupervised_learning/run_eval.py (+7, -7) @ 0864b2a4
@@ -35,13 +35,13 @@ import sonnet as snt

 from tensorflow.contrib.framework.python.framework import checkpoint_utils

-flags.DEFINE_string("checkpoint", None, "Dir to load pretrained update rule from")
+flags.DEFINE_string("checkpoint_dir", None, "Dir to load pretrained update rule from")
 flags.DEFINE_string("train_log_dir", None, "Training log directory")

 FLAGS = flags.FLAGS


-def train(train_log_dir, checkpoint, eval_every_n_steps=10, num_steps=3000):
+def train(train_log_dir, checkpoint_dir, eval_every_n_steps=10, num_steps=3000):
   dataset_fn = datasets.mnist.TinyMnist
   w_learner_fn = architectures.more_local_weight_update.MoreLocalWeightUpdateWLearner
   theta_process_fn = architectures.more_local_weight_update.MoreLocalWeightUpdateProcess
@@ -77,8 +77,8 @@ def train(train_log_dir, checkpoint, eval_every_n_steps=10, num_steps=3000):

   summary_op = tf.summary.merge_all()
   file_writer = summary_utils.LoggingFileWriter(train_log_dir, regexes=[".*"])

-  if checkpoint:
-    str_var_list = checkpoint_utils.list_variables(checkpoint)
+  if checkpoint_dir:
+    str_var_list = checkpoint_utils.list_variables(checkpoint_dir)
     name_to_v_map = {v.op.name: v for v in tf.all_variables()}
     var_list = [name_to_v_map[vn] for vn, _ in str_var_list if vn in name_to_v_map]
@@ -99,9 +99,9 @@ def train(train_log_dir, checkpoint, eval_every_n_steps=10, num_steps=3000):

     # global step should be restored from the evals job checkpoint or zero for fresh.
     step = sess.run(global_step)
-    if step == 0 and checkpoint:
+    if step == 0 and checkpoint_dir:
       tf.logging.info("force restore")
-      saver.restore(sess, checkpoint)
+      saver.restore(sess, checkpoint_dir)
       tf.logging.info("force restore done")
       sess.run(reset_global_step)
       step = sess.run(global_step)
@@ -115,7 +115,7 @@ def train(train_log_dir, checkpoint, eval_every_n_steps=10, num_steps=3000):

 def main(argv):
-  train(FLAGS.train_log_dir, FLAGS.checkpoint)
+  train(FLAGS.train_log_dir, FLAGS.checkpoint_dir)


 if __name__ == "__main__":
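The two middle hunks sit inside the script's partial-restore logic: the variables stored in the pretrained checkpoint are intersected with the variables in the current graph before a Saver is built, and the restore is forced only while the global step is still zero. A minimal standalone sketch of that intersection-then-restore pattern, assuming TF 1.x and a hypothetical checkpoint path, is shown below; it is not the repository's code.

# Hedged sketch of the partial-restore pattern (TF 1.x assumed); "ckpt_path"
# and the stand-in variables are hypothetical placeholders.
import tensorflow as tf
from tensorflow.contrib.framework.python.framework import checkpoint_utils

ckpt_path = "/tmp/pretrained_update_rule"  # hypothetical checkpoint prefix

graph = tf.Graph()
with graph.as_default():
  # Stand-in variables; the real script builds the learned update-rule graph.
  w = tf.get_variable("w", shape=[3], initializer=tf.zeros_initializer())
  b = tf.get_variable("b", shape=[], initializer=tf.zeros_initializer())

  # list_variables returns (name, shape) pairs for every tensor stored in the
  # checkpoint; keep only those that also exist in the current graph.
  # (tf.global_variables() is the non-deprecated spelling of tf.all_variables().)
  str_var_list = checkpoint_utils.list_variables(ckpt_path)
  name_to_v_map = {v.op.name: v for v in tf.global_variables()}
  var_list = [name_to_v_map[vn] for vn, _ in str_var_list if vn in name_to_v_map]

  saver = tf.train.Saver(var_list=var_list)
  init_op = tf.global_variables_initializer()

with tf.Session(graph=graph) as sess:
  sess.run(init_op)
  # Restore only the intersection; graph variables missing from the checkpoint
  # keep their freshly initialized values.
  saver.restore(sess, ckpt_path)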