Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
612ec83d
Commit
612ec83d
authored
Jul 18, 2018
by
Asim Shankar
Browse files
[official/mnist]: Avoid some now unnecessary 'tfe' symbols.
parent
a141d020
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
5 additions
and
4 deletions
+5
-4
official/mnist/mnist_eager.py
official/mnist/mnist_eager.py
+5
-4
No files found.
official/mnist/mnist_eager.py
View file @
612ec83d
...
...
@@ -33,7 +33,6 @@ import time
from
absl
import
app
as
absl_app
from
absl
import
flags
import
tensorflow
as
tf
import
tensorflow.contrib.eager
as
tfe
# pylint: enable=g-bad-import-order
from
official.mnist
import
dataset
as
mnist_dataset
...
...
@@ -42,6 +41,8 @@ from official.utils.flags import core as flags_core
from
official.utils.misc
import
model_helpers
tfe
=
tf
.
contrib
.
eager
def
loss
(
logits
,
labels
):
return
tf
.
reduce_mean
(
tf
.
nn
.
sparse_softmax_cross_entropy_with_logits
(
...
...
@@ -60,7 +61,7 @@ def train(model, optimizer, dataset, step_counter, log_interval=None):
"""Trains model on `dataset` using `optimizer`."""
start
=
time
.
time
()
for
(
batch
,
(
images
,
labels
))
in
enumerate
(
tfe
.
Iterator
(
dataset
)
)
:
for
(
batch
,
(
images
,
labels
))
in
enumerate
(
dataset
):
with
tf
.
contrib
.
summary
.
record_summaries_every_n_global_steps
(
10
,
global_step
=
step_counter
):
# Record the operations used to compute the loss given the input,
...
...
@@ -85,7 +86,7 @@ def test(model, dataset):
avg_loss
=
tfe
.
metrics
.
Mean
(
'loss'
)
accuracy
=
tfe
.
metrics
.
Accuracy
(
'accuracy'
)
for
(
images
,
labels
)
in
tfe
.
Iterator
(
dataset
)
:
for
(
images
,
labels
)
in
dataset
:
logits
=
model
(
images
,
training
=
False
)
avg_loss
(
loss
(
logits
,
labels
))
accuracy
(
...
...
@@ -145,7 +146,7 @@ def run_mnist_eager(flags_obj):
# Create and restore checkpoint (if one exists on the path)
checkpoint_prefix
=
os
.
path
.
join
(
flags_obj
.
model_dir
,
'ckpt'
)
step_counter
=
tf
.
train
.
get_or_create_global_step
()
checkpoint
=
tf
e
.
Checkpoint
(
checkpoint
=
tf
.
train
.
Checkpoint
(
model
=
model
,
optimizer
=
optimizer
,
step_counter
=
step_counter
)
# Restore variables on creation if a checkpoint exists.
checkpoint
.
restore
(
tf
.
train
.
latest_checkpoint
(
flags_obj
.
model_dir
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment