Commit b6161f67 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Enable checkpoint.

PiperOrigin-RevId: 286324485
parent caa5158f
......@@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from absl import app
from absl import flags
from absl import logging
......@@ -253,6 +255,14 @@ def run(flags_obj):
optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
optimizer, loss_scale)
current_step = 0
checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
latest_checkpoint = tf.train.latest_checkpoint(flags_obj.model_dir)
if latest_checkpoint:
checkpoint.restore(latest_checkpoint)
logging.info("Load checkpoint %s", latest_checkpoint)
current_step = optimizer.iterations.numpy()
train_loss = tf.keras.metrics.Mean('train_loss', dtype=tf.float32)
training_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
'training_accuracy', dtype=tf.float32)
......@@ -337,7 +347,7 @@ def run(flags_obj):
train_iter = iter(train_ds)
time_callback.on_train_begin()
for epoch in range(train_epochs):
for epoch in range(current_step // per_epoch_steps, train_epochs):
train_loss.reset_states()
training_accuracy.reset_states()
......@@ -375,6 +385,12 @@ def run(flags_obj):
test_accuracy.result().numpy(),
epoch + 1)
if flags_obj.enable_checkpoint_and_export:
checkpoint_name = checkpoint.save(
os.path.join(flags_obj.model_dir,
'model.ckpt-{}'.format(epoch + 1)))
logging.info('Saved checkpoint to %s', checkpoint_name)
if summary_writer:
current_steps = steps_in_current_epoch + (epoch * per_epoch_steps)
with summary_writer.as_default():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment