Commit 08f45dc4 authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 300203487
parent 682d36b5
@@ -205,6 +205,12 @@ def define_transformer_flags():
           'use padded_decode, it has not been tested. In addition, this method '
           'will introduce unnecessary overheads which grow quadratically with '
           'the max sequence length.'))
+  flags.DEFINE_bool(
+      name='enable_checkpointing',
+      default=True,
+      help=flags_core.help_wrap(
+          'Whether to do checkpointing during training. When running under '
+          'benchmark harness, we will avoid checkpointing.'))
 
   flags_core.set_defaults(data_dir='/tmp/translate_ende',
                           model_dir='/tmp/transformer_model',
@@ -159,6 +159,7 @@ class TransformerTask(object):
     params["enable_tensorboard"] = flags_obj.enable_tensorboard
     params["enable_metrics_in_training"] = flags_obj.enable_metrics_in_training
     params["steps_between_evals"] = flags_obj.steps_between_evals
+    params["enable_checkpointing"] = flags_obj.enable_checkpointing
 
     self.distribution_strategy = distribution_utils.get_distribution_strategy(
         distribution_strategy=flags_obj.distribution_strategy,
@@ -313,10 +314,12 @@ class TransformerTask(object):
               tf.compat.v2.summary.scalar(metric_obj.name, metric_obj.result(),
                                           current_step)
 
-        checkpoint_name = checkpoint.save(
-            os.path.join(flags_obj.model_dir,
-                         "ctl_step_{}.ckpt".format(current_step)))
-        logging.info("Saved checkpoint to %s", checkpoint_name)
+        if flags_obj.enable_checkpointing:
+          # avoid check-pointing when running for benchmarking.
+          checkpoint_name = checkpoint.save(
+              os.path.join(flags_obj.model_dir,
+                           "ctl_step_{}.ckpt".format(current_step)))
+          logging.info("Saved checkpoint to %s", checkpoint_name)
       else:
         if self.use_tpu:
           raise NotImplementedError(
@@ -397,10 +400,11 @@ class TransformerTask(object):
     scheduler_callback = optimizer.LearningRateScheduler(sfunc, init_steps)
     callbacks = misc.get_callbacks(params["steps_between_evals"])
     callbacks.append(scheduler_callback)
-    ckpt_full_path = os.path.join(cur_log_dir, "cp-{epoch:04d}.ckpt")
-    callbacks.append(
-        tf.keras.callbacks.ModelCheckpoint(
-            ckpt_full_path, save_weights_only=True))
+    if params["enable_checkpointing"]:
+      ckpt_full_path = os.path.join(cur_log_dir, "cp-{epoch:04d}.ckpt")
+      callbacks.append(
+          tf.keras.callbacks.ModelCheckpoint(
+              ckpt_full_path, save_weights_only=True))
     return callbacks
 
   def _load_weights_if_possible(self, model, init_weight_path=None):
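For orientation, here is a minimal, self-contained sketch (not the repository code) of the pattern this commit introduces: one boolean flag gates both checkpointing paths, so a benchmark harness can skip checkpoint I/O entirely. Only the flag name enable_checkpointing comes from the diff above; the toy model and the make_callbacks helper are illustrative assumptions.

# Minimal sketch under the assumptions stated above. It mirrors the commit's
# pattern: a single boolean flag gates both the Keras ModelCheckpoint callback
# and an explicit tf.train.Checkpoint save.
import os

from absl import app
from absl import flags
from absl import logging
import tensorflow as tf

FLAGS = flags.FLAGS
flags.DEFINE_bool(
    'enable_checkpointing', True,
    'Whether to save checkpoints during training. Benchmark runs typically '
    'disable this to avoid checkpoint I/O overhead.')
flags.DEFINE_string('model_dir', '/tmp/toy_model', 'Where to write checkpoints.')


def make_callbacks(model_dir, enable_checkpointing):
  """Builds Keras callbacks; ModelCheckpoint is added only when enabled."""
  callbacks = []
  if enable_checkpointing:
    ckpt_full_path = os.path.join(model_dir, 'cp-{epoch:04d}.ckpt')
    callbacks.append(
        tf.keras.callbacks.ModelCheckpoint(
            ckpt_full_path, save_weights_only=True))
  return callbacks


def main(_):
  model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
  model.compile(optimizer='sgd', loss='mse')

  x = tf.random.normal([64, 4])
  y = tf.random.normal([64, 1])
  model.fit(
      x, y, epochs=2,
      callbacks=make_callbacks(FLAGS.model_dir, FLAGS.enable_checkpointing))

  # The same gate applied to a custom-training-loop style checkpoint.
  checkpoint = tf.train.Checkpoint(model=model)
  if FLAGS.enable_checkpointing:
    checkpoint_name = checkpoint.save(os.path.join(FLAGS.model_dir, 'ctl_step'))
    logging.info('Saved checkpoint to %s', checkpoint_name)


if __name__ == '__main__':
  app.run(main)

Since these are absl boolean flags, checkpointing can be turned off on the command line with --enable_checkpointing=false (or --noenable_checkpointing) for benchmark runs.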