Commit 795a3f7d authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Fix TF2 3D Unet to standard model garden recommended style.

PiperOrigin-RevId: 306752053
parent 5741cef6
......@@ -33,6 +33,7 @@ from official.nlp.nhnet import models
from official.nlp.nhnet import optimizer
from official.nlp.transformer import metrics as transformer_metrics
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils
FLAGS = flags.FLAGS
......@@ -122,18 +123,6 @@ class Trainer(tf.keras.Model):
}
class SimpleCheckpoint(tf.keras.callbacks.Callback):
"""Keras callback to save tf.train.Checkpoints."""
def __init__(self, checkpoint_manager):
super(SimpleCheckpoint, self).__init__()
self.checkpoint_manager = checkpoint_manager
def on_epoch_end(self, epoch, logs=None):
step_counter = self.checkpoint_manager._step_counter.numpy()
self.checkpoint_manager.save(checkpoint_number=step_counter)
def train(params, strategy, dataset=None):
"""Runs training."""
......@@ -168,7 +157,7 @@ def train(params, strategy, dataset=None):
if checkpoint_manager.restore_or_initialize():
logging.info("Training restored from the checkpoints in: %s",
FLAGS.model_dir)
checkpoint_callback = SimpleCheckpoint(checkpoint_manager)
checkpoint_callback = keras_utils.SimpleCheckpoint(checkpoint_manager)
# Trains the model.
steps_per_epoch = min(FLAGS.train_steps, FLAGS.checkpoint_interval)
......
......@@ -164,6 +164,18 @@ def get_profiler_callback(model_dir, profile_steps, enable_tensorboard,
return ProfilerCallback(model_dir, start_step, stop_step, steps_per_epoch)
class SimpleCheckpoint(tf.keras.callbacks.Callback):
"""Keras callback to save tf.train.Checkpoints."""
def __init__(self, checkpoint_manager):
super(SimpleCheckpoint, self).__init__()
self.checkpoint_manager = checkpoint_manager
def on_epoch_end(self, epoch, logs=None):
step_counter = self.checkpoint_manager._step_counter.numpy() # pylint: disable=protected-access
self.checkpoint_manager.save(checkpoint_number=step_counter)
class ProfilerCallback(tf.keras.callbacks.Callback):
"""Save profiles in specified step range to log directory."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment