"docs/git@developer.sourcefind.cn:hehl2/torchaudio.git" did not exist on "365313edd2658e5d048d97fad04c7729deb9815b"
Commit 795a3f7d authored by A. Unique TensorFlower

Fix TF2 3D Unet to standard model garden recommended style.

PiperOrigin-RevId: 306752053
parent 5741cef6
@@ -33,6 +33,7 @@
 from official.nlp.nhnet import models
 from official.nlp.nhnet import optimizer
 from official.nlp.transformer import metrics as transformer_metrics
 from official.utils.misc import distribution_utils
+from official.utils.misc import keras_utils

 FLAGS = flags.FLAGS
@@ -122,18 +123,6 @@ class Trainer(tf.keras.Model):
     }


-class SimpleCheckpoint(tf.keras.callbacks.Callback):
-  """Keras callback to save tf.train.Checkpoints."""
-
-  def __init__(self, checkpoint_manager):
-    super(SimpleCheckpoint, self).__init__()
-    self.checkpoint_manager = checkpoint_manager
-
-  def on_epoch_end(self, epoch, logs=None):
-    step_counter = self.checkpoint_manager._step_counter.numpy()
-    self.checkpoint_manager.save(checkpoint_number=step_counter)
-
-
 def train(params, strategy, dataset=None):
   """Runs training."""
@@ -168,7 +157,7 @@ def train(params, strategy, dataset=None):
   if checkpoint_manager.restore_or_initialize():
     logging.info("Training restored from the checkpoints in: %s",
                  FLAGS.model_dir)
-  checkpoint_callback = SimpleCheckpoint(checkpoint_manager)
+  checkpoint_callback = keras_utils.SimpleCheckpoint(checkpoint_manager)
   # Trains the model.
   steps_per_epoch = min(FLAGS.train_steps, FLAGS.checkpoint_interval)
@@ -164,6 +164,18 @@ def get_profiler_callback(model_dir, profile_steps, enable_tensorboard,
   return ProfilerCallback(model_dir, start_step, stop_step, steps_per_epoch)


+class SimpleCheckpoint(tf.keras.callbacks.Callback):
+  """Keras callback to save tf.train.Checkpoints."""
+
+  def __init__(self, checkpoint_manager):
+    super(SimpleCheckpoint, self).__init__()
+    self.checkpoint_manager = checkpoint_manager
+
+  def on_epoch_end(self, epoch, logs=None):
+    step_counter = self.checkpoint_manager._step_counter.numpy()  # pylint: disable=protected-access
+    self.checkpoint_manager.save(checkpoint_number=step_counter)
+
+
 class ProfilerCallback(tf.keras.callbacks.Callback):
   """Save profiles in specified step range to log directory."""
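The diff moves SimpleCheckpoint out of the NHNet trainer and into official/utils/misc/keras_utils, next to ProfilerCallback, and updates the one call site to keras_utils.SimpleCheckpoint. For context, here is a minimal sketch (not part of the diff) of how the relocated callback is typically wired into Keras training; the toy model, data, and checkpoint directory are hypothetical, and only SimpleCheckpoint and the tf.train checkpointing APIs appear in the diff above. Because the callback reads the manager's private _step_counter (hence the pylint disable in the diff), the CheckpointManager must be constructed with step_counter and checkpoint_interval set.

```python
import numpy as np
import tensorflow as tf

from official.utils.misc import keras_utils

# Toy model standing in for the real trainer (hypothetical).
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer="adam", loss="mse")

# SimpleCheckpoint numbers each save with the manager's step counter;
# CheckpointManager keeps that counter when step_counter and
# checkpoint_interval are both passed.
checkpoint = tf.train.Checkpoint(model=model, optimizer=model.optimizer)
manager = tf.train.CheckpointManager(
    checkpoint,
    directory="/tmp/model_dir",  # hypothetical directory
    max_to_keep=3,
    step_counter=model.optimizer.iterations,
    checkpoint_interval=1)

# Resume from the latest checkpoint in the directory, if one exists.
manager.restore_or_initialize()

# The callback saves a checkpoint at the end of every epoch.
x = np.random.rand(128, 4).astype("float32")
y = np.random.rand(128, 1).astype("float32")
model.fit(x, y, epochs=2,
          callbacks=[keras_utils.SimpleCheckpoint(manager)])
```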