ModelZoo / ResNet50_tensorflow · Commits

Commit 40f8e23e
Authored Mar 06, 2018 by Allen Lavoie
Parent: d4a4dd04

Update the eager MNIST example to use object-based checkpointing
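For readers unfamiliar with the feature this commit adopts: object-based checkpointing tracks objects (model, optimizer, step counter) through the attribute names given to a `Checkpoint` object, rather than saving variables under their internal graph names the way the older name-based `tfe.Saver` does. Below is a minimal sketch of the pattern using the current `tf.train.Checkpoint` name (the commit itself uses the contrib-era alias `tfe.Checkpoint`, which exposes the same interface); the model, optimizer, and directory paths are made up for illustration.

```python
import tensorflow as tf

# Minimal sketch of object-based checkpointing. Dependencies are tracked
# through the keyword names given to Checkpoint below, so renaming variables
# inside the model does not invalidate previously written checkpoints.
model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)
step_counter = tf.Variable(0, dtype=tf.int64, trainable=False)

checkpoint = tf.train.Checkpoint(
    model=model, optimizer=optimizer, step_counter=step_counter)

# Restore the newest checkpoint in the directory if one exists;
# restore(None) is a no-op, so a cold start uses the same code path.
checkpoint.restore(tf.train.latest_checkpoint('/tmp/mnist_ckpt'))

# ... training loop would run here ...

save_path = checkpoint.save('/tmp/mnist_ckpt/ckpt')  # writes ckpt-1, ckpt-2, ...
```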
Changes: 2 changed files with 22 additions and 20 deletions (+22 −20)

  official/mnist/mnist_eager.py       +20 −19
  official/mnist/mnist_eager_test.py   +2  −1
official/mnist/mnist_eager.py
```diff
@@ -53,14 +53,13 @@ def compute_accuracy(logits, labels):
       tf.cast(tf.equal(predictions, labels), dtype=tf.float32)) / batch_size
 
 
-def train(model, optimizer, dataset, log_interval=None):
+def train(model, optimizer, dataset, step_counter, log_interval=None):
   """Trains model on `dataset` using `optimizer`."""
-  global_step = tf.train.get_or_create_global_step()
   start = time.time()
   for (batch, (images, labels)) in enumerate(tfe.Iterator(dataset)):
-    with tf.contrib.summary.record_summaries_every_n_global_steps(10):
+    with tf.contrib.summary.record_summaries_every_n_global_steps(
+        10, global_step=step_counter):
       # Record the operations used to compute the loss given the input,
       # so that the gradient of the loss with respect to the variables
       # can be computed.
```
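The first two hunks thread an explicit `step_counter` through `train()` instead of looking up the implicit global step inside the function. A hedged sketch of the summary-gating idiom this enables, using the same contrib-era calls as the diff (it assumes a TF 1.x build with eager execution enabled; the log directory and loss value are placeholders):

```python
import tensorflow as tf

tf.enable_eager_execution()  # TF 1.x-era eager mode, as in this example

# The caller owns the counter and passes it in; summary recording is then
# gated on that counter rather than on hidden global state.
step_counter = tf.train.get_or_create_global_step()
writer = tf.contrib.summary.create_file_writer('/tmp/demo_log', name='train')
with writer.as_default():
  with tf.contrib.summary.record_summaries_every_n_global_steps(
      10, global_step=step_counter):
    tf.contrib.summary.scalar('loss', 0.5)  # placeholder scalar
  step_counter.assign_add(1)
```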
```diff
@@ -71,7 +70,7 @@ def train(model, optimizer, dataset, log_interval=None):
         tf.contrib.summary.scalar('accuracy', compute_accuracy(logits, labels))
       grads = tape.gradient(loss_value, model.variables)
       optimizer.apply_gradients(
-          zip(grads, model.variables), global_step=global_step)
+          zip(grads, model.variables), global_step=step_counter)
       if log_interval and batch % log_interval == 0:
         rate = log_interval / (time.time() - start)
         print('Step #%d\tLoss: %.6f (%d steps/sec)' % (batch, loss_value, rate))
```
```diff
@@ -128,23 +127,25 @@ def main(_):
   test_summary_writer = tf.contrib.summary.create_file_writer(
       test_dir, flush_millis=10000, name='test')
   checkpoint_prefix = os.path.join(FLAGS.checkpoint_dir, 'ckpt')
-
-  # Train and evaluate for 11 epochs.
+  step_counter = tf.train.get_or_create_global_step()
+  checkpoint = tfe.Checkpoint(
+      model=model, optimizer=optimizer, step_counter=step_counter)
+  # Restore variables on creation if a checkpoint exists.
+  checkpoint.restore(tf.train.latest_checkpoint(FLAGS.checkpoint_dir))
+  # Train and evaluate for 10 epochs.
   with tf.device(device):
-    for epoch in range(1, 11):
-      with tfe.restore_variables_on_create(
-          tf.train.latest_checkpoint(FLAGS.checkpoint_dir)):
-        global_step = tf.train.get_or_create_global_step()
-        start = time.time()
-        with summary_writer.as_default():
-          train(model, optimizer, train_ds, FLAGS.log_interval)
-        end = time.time()
-        print('\nTrain time for epoch #%d (global step %d): %f' %
-              (epoch, global_step.numpy(), end - start))
+    for _ in range(10):
+      start = time.time()
+      with summary_writer.as_default():
+        train(model, optimizer, train_ds, step_counter, FLAGS.log_interval)
+      end = time.time()
+      print('\nTrain time for epoch #%d (%d total steps): %f' %
+            (checkpoint.save_counter.numpy() + 1,
+             step_counter.numpy(),
+             end - start))
       with test_summary_writer.as_default():
         test(model, test_ds)
-      all_variables = (model.variables + optimizer.variables() + [global_step])
-      tfe.Saver(all_variables).save(checkpoint_prefix, global_step=global_step)
+      checkpoint.save(checkpoint_prefix)
 
 
 if __name__ == '__main__':
```
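Note that the epoch number printed by `main()` now comes from `checkpoint.save_counter` (the number of completed saves) rather than a loop index, so a run restarted from a checkpoint continues its epoch numbering instead of resetting to 1. A hypothetical round-trip check of the restore behavior, again written against the current `tf.train.Checkpoint` name with made-up paths:

```python
import tensorflow as tf

# A value saved under one object graph is recovered by a freshly built
# object graph whose attribute names match.
v = tf.Variable(3.0)
path = tf.train.Checkpoint(v=v).save('/tmp/demo/ckpt')

restored_v = tf.Variable(0.0)
status = tf.train.Checkpoint(v=restored_v).restore(path)
status.assert_consumed()  # every saved value found a matching object
assert restored_v.numpy() == 3.0
```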
official/mnist/mnist_eager_test.py
```diff
@@ -46,7 +46,8 @@ def train(defun=False):
   optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
   dataset = random_dataset()
   with tf.device(device()):
-    mnist_eager.train(model, optimizer, dataset)
+    mnist_eager.train(model, optimizer, dataset,
+                      step_counter=tf.train.get_or_create_global_step())
 
 
 def evaluate(defun=False):
```
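Since `train()` no longer creates its own counter, the test must now supply one explicitly. That injection is what makes step numbering controllable from the outside; a small sketch of the idea (the helper name is made up, and a plain `tf.Variable` stands in for the global step):

```python
import tensorflow as tf

def make_step_counter():
  # Hypothetical stand-in for tf.train.get_or_create_global_step(): any
  # int64 variable works where an explicit step counter is expected.
  return tf.Variable(0, dtype=tf.int64, trainable=False)

# Each caller can own an independent counter instead of sharing one piece
# of hidden global state.
counter_a = make_step_counter()
counter_b = make_step_counter()
counter_a.assign_add(5)
assert int(counter_a.numpy()) == 5
assert int(counter_b.numpy()) == 0
```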