Commit b547c6fa authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Avoid binding global_step to the model. It is buggy due to tf.train.Checkpoint's delayed restoration...

PiperOrigin-RevId: 333591143
parent eaff981c
...@@ -15,11 +15,6 @@ ...@@ -15,11 +15,6 @@
# ============================================================================== # ==============================================================================
"""Evaluation for Bert2Bert.""" """Evaluation for Bert2Bert."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import os import os
# Import libraries # Import libraries
from absl import logging from absl import logging
...@@ -114,7 +109,6 @@ def continuous_eval(strategy, ...@@ -114,7 +109,6 @@ def continuous_eval(strategy,
dtype=tf.int64, dtype=tf.int64,
aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA, aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
shape=[]) shape=[])
model.global_step = global_step
@tf.function @tf.function
def test_step(inputs): def test_step(inputs):
...@@ -149,7 +143,7 @@ def continuous_eval(strategy, ...@@ -149,7 +143,7 @@ def continuous_eval(strategy,
eval_results = {} eval_results = {}
for latest_checkpoint in tf.train.checkpoints_iterator( for latest_checkpoint in tf.train.checkpoints_iterator(
model_dir, timeout=timeout): model_dir, timeout=timeout):
checkpoint = tf.train.Checkpoint(model=model) checkpoint = tf.train.Checkpoint(model=model, global_step=global_step)
checkpoint.restore(latest_checkpoint).expect_partial() checkpoint.restore(latest_checkpoint).expect_partial()
logging.info("Loaded checkpoint %s", latest_checkpoint) logging.info("Loaded checkpoint %s", latest_checkpoint)
...@@ -162,7 +156,7 @@ def continuous_eval(strategy, ...@@ -162,7 +156,7 @@ def continuous_eval(strategy,
metric.update_state(func(logits.numpy(), targets.numpy())) metric.update_state(func(logits.numpy(), targets.numpy()))
with eval_summary_writer.as_default(): with eval_summary_writer.as_default():
step = model.global_step.numpy() step = global_step.numpy()
for metric, _ in metrics_and_funcs: for metric, _ in metrics_and_funcs:
eval_results[metric.name] = metric.result().numpy().astype(float) eval_results[metric.name] = metric.result().numpy().astype(float)
tf.summary.scalar( tf.summary.scalar(
......
...@@ -145,7 +145,6 @@ def train(params, strategy, dataset=None): ...@@ -145,7 +145,6 @@ def train(params, strategy, dataset=None):
FLAGS.model_type, params, init_checkpoint=FLAGS.init_checkpoint) FLAGS.model_type, params, init_checkpoint=FLAGS.init_checkpoint)
opt = optimizer.create_optimizer(params) opt = optimizer.create_optimizer(params)
trainer = Trainer(model, params) trainer = Trainer(model, params)
model.global_step = opt.iterations
trainer.compile( trainer.compile(
optimizer=opt, optimizer=opt,
...@@ -153,12 +152,13 @@ def train(params, strategy, dataset=None): ...@@ -153,12 +152,13 @@ def train(params, strategy, dataset=None):
summary_dir = os.path.join(FLAGS.model_dir, "summaries") summary_dir = os.path.join(FLAGS.model_dir, "summaries")
summary_callback = tf.keras.callbacks.TensorBoard( summary_callback = tf.keras.callbacks.TensorBoard(
summary_dir, update_freq=max(100, FLAGS.steps_per_loop)) summary_dir, update_freq=max(100, FLAGS.steps_per_loop))
checkpoint = tf.train.Checkpoint(model=model, optimizer=opt) checkpoint = tf.train.Checkpoint(
model=model, optimizer=opt, global_step=opt.iterations)
checkpoint_manager = tf.train.CheckpointManager( checkpoint_manager = tf.train.CheckpointManager(
checkpoint, checkpoint,
directory=FLAGS.model_dir, directory=FLAGS.model_dir,
max_to_keep=10, max_to_keep=10,
step_counter=model.global_step, step_counter=opt.iterations,
checkpoint_interval=FLAGS.checkpoint_interval) checkpoint_interval=FLAGS.checkpoint_interval)
if checkpoint_manager.restore_or_initialize(): if checkpoint_manager.restore_or_initialize():
logging.info("Training restored from the checkpoints in: %s", logging.info("Training restored from the checkpoints in: %s",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment