Commit b6161f67 authored by A. Unique TensorFlower

Enable checkpoint.

PiperOrigin-RevId: 286324485
parent caa5158f
@@ -18,6 +18,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import os
+
 from absl import app
 from absl import flags
 from absl import logging
@@ -253,6 +255,14 @@ def run(flags_obj):
     optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
         optimizer, loss_scale)
 
+  current_step = 0
+  checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
+  latest_checkpoint = tf.train.latest_checkpoint(flags_obj.model_dir)
+  if latest_checkpoint:
+    checkpoint.restore(latest_checkpoint)
+    logging.info("Load checkpoint %s", latest_checkpoint)
+    current_step = optimizer.iterations.numpy()
+
   train_loss = tf.keras.metrics.Mean('train_loss', dtype=tf.float32)
   training_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
       'training_accuracy', dtype=tf.float32)
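The restore block added above tracks the model and optimizer in a single object-based checkpoint and recovers the global step from `optimizer.iterations`. Below is a minimal standalone sketch of that restore path; the model, optimizer, and `model_dir` values are hypothetical stand-ins, not the objects defined in this script.

```python
import tensorflow as tf

# Hypothetical stand-ins for illustration only.
model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
model_dir = '/tmp/resnet_ctl_demo'

# Tracking both objects means the weights, the optimizer slots, and the
# step counter in optimizer.iterations are saved and restored together.
checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)

# latest_checkpoint returns None when the directory holds no checkpoint,
# so a fresh run simply starts from step 0.
current_step = 0
latest = tf.train.latest_checkpoint(model_dir)
if latest:
    checkpoint.restore(latest)
    current_step = optimizer.iterations.numpy()
print('resuming from step', current_step)
```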
@@ -337,7 +347,7 @@ def run(flags_obj):
   train_iter = iter(train_ds)
   time_callback.on_train_begin()
-  for epoch in range(train_epochs):
+  for epoch in range(current_step // per_epoch_steps, train_epochs):
     train_loss.reset_states()
     training_accuracy.reset_states()
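The loop change above makes training resume at the epoch implied by the restored step count rather than at epoch 0. A small illustration of that arithmetic, with made-up numbers:

```python
# Made-up numbers, not values computed by this script.
per_epoch_steps = 1560   # steps per epoch
current_step = 3120      # global step restored from optimizer.iterations

start_epoch = current_step // per_epoch_steps   # -> 2
# Epochs 0 and 1 are already reflected in the checkpoint, so the loop
# continues with epoch index 2.
for epoch in range(start_epoch, 5):
    print('running epoch', epoch)   # prints 2, 3, 4
```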
@@ -375,6 +385,12 @@ def run(flags_obj):
                  test_accuracy.result().numpy(),
                  epoch + 1)
 
+    if flags_obj.enable_checkpoint_and_export:
+      checkpoint_name = checkpoint.save(
+          os.path.join(flags_obj.model_dir,
+                       'model.ckpt-{}'.format(epoch + 1)))
+      logging.info('Saved checkpoint to %s', checkpoint_name)
+
     if summary_writer:
       current_steps = steps_in_current_epoch + (epoch * per_epoch_steps)
       with summary_writer.as_default():
...
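The last hunk writes a checkpoint at the end of every epoch when flags_obj.enable_checkpoint_and_export is set. A minimal sketch of that save path follows, again with hypothetical model, optimizer, and directory stand-ins; note that `tf.train.Checkpoint.save` appends its own save counter to the supplied prefix and updates the bookkeeping file that `tf.train.latest_checkpoint` reads on the next run.

```python
import os
import tensorflow as tf

# Hypothetical stand-ins; in the script these come from run(flags_obj).
model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
optimizer = tf.keras.optimizers.SGD()
checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
model_dir = '/tmp/resnet_ctl_demo'
train_epochs = 2

for epoch in range(train_epochs):
    # ... one epoch of training would run here ...

    # save() writes <prefix>-<save_counter>.index / .data-* files and
    # refreshes the "checkpoint" state file in model_dir.
    prefix = os.path.join(model_dir, 'model.ckpt-{}'.format(epoch + 1))
    saved_path = checkpoint.save(prefix)
    print('Saved checkpoint to', saved_path)
```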