Commit b547c6fa authored by Hongkun Yu, committed by A. Unique TensorFlower

Avoid binding global_step to the model. It is buggy due to tf.train.Checkpoint delayed restoration...

PiperOrigin-RevId: 333591143
parent eaff981c
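For context, the sketch below illustrates the pitfall the commit message refers to, under standard TF2 `tf.train.Checkpoint` semantics: values for objects that cannot be matched yet at restore time are held back (delayed restoration), so a step counter that is only reachable through the model may not hold the restored value at the moment it is read, whereas a variable tracked directly by the Checkpoint is restored eagerly. This is a minimal, hypothetical sketch; the stand-in model, variable names, and `/tmp/model_dir` path are assumptions, not the repo's actual code.

```python
import tensorflow as tf

global_step = tf.Variable(0, trainable=False, dtype=tf.int64)

# Pattern removed by this commit: global_step is only reachable through the
# model object, so restoring Checkpoint(model=model) can leave its value to
# delayed restoration logic tied to the (possibly still unbuilt) model.
model = tf.keras.Sequential([tf.keras.layers.Dense(4)])  # stand-in model
model.global_step = global_step
ckpt_via_model = tf.train.Checkpoint(model=model)

# Pattern introduced by this commit: track global_step as a direct dependency
# of the Checkpoint so its value is assigned as soon as restore() runs.
ckpt_direct = tf.train.Checkpoint(model=model, global_step=global_step)

latest = tf.train.latest_checkpoint("/tmp/model_dir")  # hypothetical directory
if latest:
  ckpt_direct.restore(latest).expect_partial()
  print("global_step after restore:", global_step.numpy())
```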
@@ -15,11 +15,6 @@
 # ==============================================================================
 """Evaluation for Bert2Bert."""
-from __future__ import absolute_import
-from __future__ import division
-# from __future__ import google_type_annotations
-from __future__ import print_function
 import os
 # Import libraries
 from absl import logging
@@ -114,7 +109,6 @@ def continuous_eval(strategy,
 dtype=tf.int64,
 aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
 shape=[])
-model.global_step = global_step
 @tf.function
 def test_step(inputs):
@@ -149,7 +143,7 @@ def continuous_eval(strategy,
 eval_results = {}
 for latest_checkpoint in tf.train.checkpoints_iterator(
 model_dir, timeout=timeout):
-checkpoint = tf.train.Checkpoint(model=model)
+checkpoint = tf.train.Checkpoint(model=model, global_step=global_step)
 checkpoint.restore(latest_checkpoint).expect_partial()
 logging.info("Loaded checkpoint %s", latest_checkpoint)
@@ -162,7 +156,7 @@ def continuous_eval(strategy,
 metric.update_state(func(logits.numpy(), targets.numpy()))
 with eval_summary_writer.as_default():
-step = model.global_step.numpy()
+step = global_step.numpy()
 for metric, _ in metrics_and_funcs:
 eval_results[metric.name] = metric.result().numpy().astype(float)
 tf.summary.scalar(
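Taken together, the evaluation-side pattern after this change amounts to roughly the following. This is a minimal, hypothetical sketch: the stand-in model, `model_dir`, timeout value, and the placeholder scalar are assumptions, not the file's real surrounding code.

```python
import os
import tensorflow as tf

model_dir = "/tmp/model_dir"  # hypothetical; the real value comes from flags
model = tf.keras.Sequential([tf.keras.layers.Dense(4)])  # stand-in model
eval_summary_writer = tf.summary.create_file_writer(
    os.path.join(model_dir, "summaries/eval"))

# The step counter lives outside the model and is restored as a direct
# dependency of the Checkpoint.
global_step = tf.Variable(
    0,
    trainable=False,
    dtype=tf.int64,
    aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
    shape=[])

# Yields each new checkpoint path; exits once `timeout` seconds pass with no
# new checkpoint appearing in model_dir.
for latest_checkpoint in tf.train.checkpoints_iterator(model_dir, timeout=60):
  checkpoint = tf.train.Checkpoint(model=model, global_step=global_step)
  checkpoint.restore(latest_checkpoint).expect_partial()

  with eval_summary_writer.as_default():
    step = global_step.numpy()  # read the restored counter directly
    tf.summary.scalar("eval/placeholder_metric", 0.0, step=step)
```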
@@ -145,7 +145,6 @@ def train(params, strategy, dataset=None):
 FLAGS.model_type, params, init_checkpoint=FLAGS.init_checkpoint)
 opt = optimizer.create_optimizer(params)
 trainer = Trainer(model, params)
-model.global_step = opt.iterations
 trainer.compile(
 optimizer=opt,
@@ -153,12 +152,13 @@ def train(params, strategy, dataset=None):
 summary_dir = os.path.join(FLAGS.model_dir, "summaries")
 summary_callback = tf.keras.callbacks.TensorBoard(
 summary_dir, update_freq=max(100, FLAGS.steps_per_loop))
-checkpoint = tf.train.Checkpoint(model=model, optimizer=opt)
+checkpoint = tf.train.Checkpoint(
+    model=model, optimizer=opt, global_step=opt.iterations)
 checkpoint_manager = tf.train.CheckpointManager(
 checkpoint,
 directory=FLAGS.model_dir,
 max_to_keep=10,
-step_counter=model.global_step,
+step_counter=opt.iterations,
 checkpoint_interval=FLAGS.checkpoint_interval)
 if checkpoint_manager.restore_or_initialize():
 logging.info("Training restored from the checkpoints in: %s",
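The training-side half of the change reduces to the pattern below. Again a minimal sketch: the Sequential model, Adam optimizer, directory, and interval stand in for the script's real model, optimizer.create_optimizer(params), FLAGS.model_dir, and FLAGS.checkpoint_interval.

```python
import tensorflow as tf

model_dir = "/tmp/model_dir"  # stands in for FLAGS.model_dir
model = tf.keras.Sequential([tf.keras.layers.Dense(4)])  # stand-in model
opt = tf.keras.optimizers.Adam()  # stands in for optimizer.create_optimizer(params)

# opt.iterations is the optimizer's own step counter; saving it under the name
# "global_step" lets eval jobs restore it without attaching it to the model.
checkpoint = tf.train.Checkpoint(
    model=model, optimizer=opt, global_step=opt.iterations)
checkpoint_manager = tf.train.CheckpointManager(
    checkpoint,
    directory=model_dir,
    max_to_keep=10,
    step_counter=opt.iterations,  # drives the checkpoint_interval bookkeeping
    checkpoint_interval=1000)     # stands in for FLAGS.checkpoint_interval

if checkpoint_manager.restore_or_initialize():
  print("Training restored from the checkpoints in:", model_dir)
```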