Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
b6161f67
Commit
b6161f67
authored
Dec 18, 2019
by
A. Unique TensorFlower
Browse files
Enable checkpoint.
PiperOrigin-RevId: 286324485
parent
caa5158f
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
17 additions
and
1 deletion
+17
-1
official/vision/image_classification/resnet_ctl_imagenet_main.py
...l/vision/image_classification/resnet_ctl_imagenet_main.py
+17
-1
No files found.
official/vision/image_classification/resnet_ctl_imagenet_main.py
View file @
b6161f67
...
@@ -18,6 +18,8 @@ from __future__ import absolute_import
...
@@ -18,6 +18,8 @@ from __future__ import absolute_import
from
__future__
import
division
from
__future__
import
division
from
__future__
import
print_function
from
__future__
import
print_function
import
os
from
absl
import
app
from
absl
import
app
from
absl
import
flags
from
absl
import
flags
from
absl
import
logging
from
absl
import
logging
...
@@ -253,6 +255,14 @@ def run(flags_obj):
...
@@ -253,6 +255,14 @@ def run(flags_obj):
optimizer
=
tf
.
train
.
experimental
.
enable_mixed_precision_graph_rewrite
(
optimizer
=
tf
.
train
.
experimental
.
enable_mixed_precision_graph_rewrite
(
optimizer
,
loss_scale
)
optimizer
,
loss_scale
)
current_step
=
0
checkpoint
=
tf
.
train
.
Checkpoint
(
model
=
model
,
optimizer
=
optimizer
)
latest_checkpoint
=
tf
.
train
.
latest_checkpoint
(
flags_obj
.
model_dir
)
if
latest_checkpoint
:
checkpoint
.
restore
(
latest_checkpoint
)
logging
.
info
(
"Load checkpoint %s"
,
latest_checkpoint
)
current_step
=
optimizer
.
iterations
.
numpy
()
train_loss
=
tf
.
keras
.
metrics
.
Mean
(
'train_loss'
,
dtype
=
tf
.
float32
)
train_loss
=
tf
.
keras
.
metrics
.
Mean
(
'train_loss'
,
dtype
=
tf
.
float32
)
training_accuracy
=
tf
.
keras
.
metrics
.
SparseCategoricalAccuracy
(
training_accuracy
=
tf
.
keras
.
metrics
.
SparseCategoricalAccuracy
(
'training_accuracy'
,
dtype
=
tf
.
float32
)
'training_accuracy'
,
dtype
=
tf
.
float32
)
...
@@ -337,7 +347,7 @@ def run(flags_obj):
...
@@ -337,7 +347,7 @@ def run(flags_obj):
train_iter
=
iter
(
train_ds
)
train_iter
=
iter
(
train_ds
)
time_callback
.
on_train_begin
()
time_callback
.
on_train_begin
()
for
epoch
in
range
(
train_epochs
):
for
epoch
in
range
(
current_step
//
per_epoch_steps
,
train_epochs
):
train_loss
.
reset_states
()
train_loss
.
reset_states
()
training_accuracy
.
reset_states
()
training_accuracy
.
reset_states
()
...
@@ -375,6 +385,12 @@ def run(flags_obj):
...
@@ -375,6 +385,12 @@ def run(flags_obj):
test_accuracy
.
result
().
numpy
(),
test_accuracy
.
result
().
numpy
(),
epoch
+
1
)
epoch
+
1
)
if
flags_obj
.
enable_checkpoint_and_export
:
checkpoint_name
=
checkpoint
.
save
(
os
.
path
.
join
(
flags_obj
.
model_dir
,
'model.ckpt-{}'
.
format
(
epoch
+
1
)))
logging
.
info
(
'Saved checkpoint to %s'
,
checkpoint_name
)
if
summary_writer
:
if
summary_writer
:
current_steps
=
steps_in_current_epoch
+
(
epoch
*
per_epoch_steps
)
current_steps
=
steps_in_current_epoch
+
(
epoch
*
per_epoch_steps
)
with
summary_writer
.
as_default
():
with
summary_writer
.
as_default
():
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment