Commit 08f45dc4 authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 300203487
parent 682d36b5
@@ -205,6 +205,12 @@ def define_transformer_flags():
           'use padded_decode, it has not been tested. In addition, this method '
           'will introduce unnecessary overheads which grow quadratically with '
           'the max sequence length.'))
+  flags.DEFINE_bool(
+      name='enable_checkpointing',
+      default=True,
+      help=flags_core.help_wrap(
+          'Whether to do checkpointing during training. When running under '
+          'a benchmark harness, we avoid checkpointing.'))
   flags_core.set_defaults(data_dir='/tmp/translate_ende',
                           model_dir='/tmp/transformer_model',
...
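For context, the new flag follows the standard Abseil pattern used throughout the repo. Below is a minimal, self-contained sketch of how such a flag is defined and read; it drops `flags_core.help_wrap` (which only reflows help text), and the file name is hypothetical, not part of the commit:

```python
# abseil_flag_demo.py -- hypothetical standalone example, not from this commit.
from absl import app, flags

FLAGS = flags.FLAGS
flags.DEFINE_bool(
    name='enable_checkpointing',
    default=True,
    help='Whether to do checkpointing during training.')


def main(_):
  # Abseil parses argv before main() runs, so the flag is ready to read here.
  print('checkpointing enabled:', FLAGS.enable_checkpointing)


if __name__ == '__main__':
  app.run(main)
```

Because it is a boolean flag, Abseil accepts `--enable_checkpointing=false` as well as the negated form `--noenable_checkpointing`, which is how a benchmark harness would turn checkpointing off.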
@@ -159,6 +159,7 @@ class TransformerTask(object):
     params["enable_tensorboard"] = flags_obj.enable_tensorboard
     params["enable_metrics_in_training"] = flags_obj.enable_metrics_in_training
     params["steps_between_evals"] = flags_obj.steps_between_evals
+    params["enable_checkpointing"] = flags_obj.enable_checkpointing

     self.distribution_strategy = distribution_utils.get_distribution_strategy(
         distribution_strategy=flags_obj.distribution_strategy,
@@ -313,6 +314,8 @@ class TransformerTask(object):
               tf.compat.v2.summary.scalar(metric_obj.name, metric_obj.result(),
                                           current_step)
-        checkpoint_name = checkpoint.save(
-            os.path.join(flags_obj.model_dir,
-                         "ctl_step_{}.ckpt".format(current_step)))
+        if flags_obj.enable_checkpointing:
+          # Avoid checkpointing when running for benchmarking.
+          checkpoint_name = checkpoint.save(
+              os.path.join(flags_obj.model_dir,
+                           "ctl_step_{}.ckpt".format(current_step)))
@@ -397,6 +400,7 @@ class TransformerTask(object):
     scheduler_callback = optimizer.LearningRateScheduler(sfunc, init_steps)
     callbacks = misc.get_callbacks(params["steps_between_evals"])
     callbacks.append(scheduler_callback)
-    ckpt_full_path = os.path.join(cur_log_dir, "cp-{epoch:04d}.ckpt")
-    callbacks.append(
-        tf.keras.callbacks.ModelCheckpoint(
+    if params["enable_checkpointing"]:
+      ckpt_full_path = os.path.join(cur_log_dir, "cp-{epoch:04d}.ckpt")
+      callbacks.append(
+          tf.keras.callbacks.ModelCheckpoint(
...
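In the Keras path, the flag instead decides whether a `ModelCheckpoint` callback gets registered at all; when it is absent, `model.fit` simply never writes checkpoints. A minimal sketch of that conditional wiring (the model, data, and the `save_weights_only` argument are illustrative assumptions; the diff excerpt cuts off before the `ModelCheckpoint` arguments):

```python
import os

import tensorflow as tf

params = {'enable_checkpointing': True}  # filled from flags in the real code
cur_log_dir = '/tmp/transformer_model'

callbacks = []
if params['enable_checkpointing']:
  # Epoch-numbered weight checkpoints, written at the end of each epoch.
  ckpt_full_path = os.path.join(cur_log_dir, 'cp-{epoch:04d}.ckpt')
  callbacks.append(
      tf.keras.callbacks.ModelCheckpoint(ckpt_full_path,
                                         save_weights_only=True))

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.compile(optimizer='adam', loss='mse')
model.fit(tf.zeros([8, 4]), tf.zeros([8, 1]), epochs=2, callbacks=callbacks)
```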