Commit 68104ce3 authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 284805626
parent f7fd59b8
...@@ -44,7 +44,6 @@ from official.utils.logs import logger ...@@ -44,7 +44,6 @@ from official.utils.logs import logger
from official.utils.misc import keras_utils from official.utils.misc import keras_utils
from official.utils.misc import distribution_utils from official.utils.misc import distribution_utils
INF = int(1e9) INF = int(1e9)
BLEU_DIR = "bleu" BLEU_DIR = "bleu"
_SINGLE_SAMPLE = 1 _SINGLE_SAMPLE = 1
@@ -158,6 +157,7 @@ class TransformerTask(object):
     params["batch_size"] = flags_obj.batch_size or params["default_batch_size"]
     params["repeat_dataset"] = None
     params["dtype"] = flags_core.get_tf_dtype(flags_obj)
+    params["enable_tensorboard"] = flags_obj.enable_tensorboard
     params["enable_metrics_in_training"] = flags_obj.enable_metrics_in_training
     params["steps_between_evals"] = flags_obj.steps_between_evals
@@ -183,8 +183,8 @@ class TransformerTask(object):
       # like this. What if multiple instances of TransformerTask are created?
       # We should have a better way in the tf.keras.mixed_precision API of doing
       # this.
-      loss_scale = flags_core.get_loss_scale(flags_obj,
-                                             default_for_fp16="dynamic")
+      loss_scale = flags_core.get_loss_scale(
+          flags_obj, default_for_fp16="dynamic")
       policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
           "mixed_float16", loss_scale=loss_scale)
       tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)
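This hunk only reflows the call, but for orientation, the mixed-precision setup it belongs to follows the TF 2.x experimental pattern below (a minimal sketch assuming the era's `tf.keras.mixed_precision.experimental` API; "dynamic" is the fp16 default named in the diff):

```python
import tensorflow as tf

# Sketch: a mixed_float16 policy with dynamic loss scaling, installed globally.
policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
    "mixed_float16", loss_scale="dynamic")
tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)

# Layers created after this point compute in float16 but keep float32 variables.
print(policy.compute_dtype)   # float16
print(policy.variable_dtype)  # float32
```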
@@ -206,8 +206,7 @@ class TransformerTask(object):
     params = self.params
     flags_obj = self.flags_obj
     # Sets config options.
-    keras_utils.set_session_config(
-        enable_xla=flags_obj.enable_xla)
+    keras_utils.set_session_config(enable_xla=flags_obj.enable_xla)

     _ensure_dir(flags_obj.model_dir)
     with distribution_utils.get_strategy_scope(self.distribution_strategy):
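`keras_utils.set_session_config` is a helper from `official/utils/misc`; with `enable_xla=True` it essentially turns on XLA JIT compilation. A rough sketch of that toggle under TF 2.x (the helper's real body may differ):

```python
import tensorflow as tf

def set_session_config_sketch(enable_xla=False):
  # Assumption: enabling XLA globally via the TF 2.x JIT switch.
  if enable_xla:
    tf.config.optimizer.set_jit(True)
```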
@@ -225,6 +224,14 @@ class TransformerTask(object):
       if params["use_ctl"]:
         train_loss_metric = tf.keras.metrics.Mean(
             "training_loss", dtype=tf.float32)
+        if params["enable_tensorboard"]:
+          summary_writer = tf.compat.v2.summary.create_file_writer(
+              flags_obj.model_dir)
+        else:
+          summary_writer = tf.compat.v2.summary.create_noop_writer()
+        train_metrics = [train_loss_metric]
+        if params["enable_metrics_in_training"]:
+          train_metrics = train_metrics + model.metrics
       else:
         model.compile(opt)
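The added block picks a real file writer when TensorBoard is enabled and a no-op writer otherwise, so the later summary calls never need to be guarded. Stripped of the surrounding class, the pattern looks like this (directory and flag values are placeholders):

```python
import tensorflow as tf

enable_tensorboard = True        # placeholder for params["enable_tensorboard"]
model_dir = "/tmp/transformer"   # placeholder for flags_obj.model_dir

if enable_tensorboard:
  summary_writer = tf.summary.create_file_writer(model_dir)
else:
  # Writes nothing, but keeps the downstream summary code path identical.
  summary_writer = tf.summary.create_noop_writer()
```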
@@ -303,17 +310,23 @@ class TransformerTask(object):
            raise NotImplementedError(
                "Custom training loop on GPUs is not implemented.")
          # Runs training steps.
-         train_steps(train_ds_iterator,
-                     tf.convert_to_tensor(train_steps_per_eval, dtype=tf.int32))
-         current_step += train_steps_per_eval
-         train_loss = train_loss_metric.result().numpy().astype(float)
-         logging.info("Train Step: %d/%d / loss = %s",
-                      current_step, flags_obj.train_steps, train_loss)
+         with summary_writer.as_default():
+           train_steps(
+               train_ds_iterator,
+               tf.convert_to_tensor(train_steps_per_eval, dtype=tf.int32))
+           current_step += train_steps_per_eval
+           train_loss = train_loss_metric.result().numpy().astype(float)
+           logging.info("Train Step: %d/%d / loss = %s", current_step,
+                        flags_obj.train_steps, train_loss)
+           if params["enable_tensorboard"]:
+             for metric_obj in train_metrics:
+               tf.compat.v2.summary.scalar(metric_obj.name, metric_obj.result(),
+                                           current_step)
          checkpoint_name = checkpoint.save(
-             os.path.join(
-                 flags_obj.model_dir,
-                 "ctl_step_{}.ckpt".format(current_step)))
+             os.path.join(flags_obj.model_dir,
+                          "ctl_step_{}.ckpt".format(current_step)))
          logging.info("Saved checkpoint to %s", checkpoint_name)
        else:
          if self.use_tpu:
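Inside `summary_writer.as_default()`, each metric is emitted as a scalar tagged with its name and the current global step, which is what TensorBoard plots on the x-axis. A self-contained sketch of that idiom (values are illustrative):

```python
import tensorflow as tf

summary_writer = tf.summary.create_file_writer("/tmp/transformer")  # placeholder dir
train_loss_metric = tf.keras.metrics.Mean("training_loss", dtype=tf.float32)
train_loss_metric.update_state(0.42)  # illustrative loss value
current_step = 100                    # illustrative global step

with summary_writer.as_default():
  # Record the metric under its name at the current step.
  tf.summary.scalar(train_loss_metric.name, train_loss_metric.result(),
                    step=current_step)
```

Pointing `tensorboard --logdir` at the same directory would then show the logged curve.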
@@ -391,8 +404,9 @@ class TransformerTask(object):
     callbacks = misc.get_callbacks(params["steps_between_evals"])
     callbacks.append(scheduler_callback)
     ckpt_full_path = os.path.join(cur_log_dir, "cp-{epoch:04d}.ckpt")
-    callbacks.append(tf.keras.callbacks.ModelCheckpoint(ckpt_full_path,
-                                                        save_weights_only=True))
+    callbacks.append(
+        tf.keras.callbacks.ModelCheckpoint(
+            ckpt_full_path, save_weights_only=True))
     return callbacks

   def _load_weights_if_possible(self, model, init_weight_path=None):
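The reflowed callback saves weights-only checkpoints using an epoch-templated filename. In isolation (the log directory is a placeholder), the same callback reads:

```python
import os
import tensorflow as tf

cur_log_dir = "/tmp/transformer"  # placeholder for the run's log directory
ckpt_full_path = os.path.join(cur_log_dir, "cp-{epoch:04d}.ckpt")

# {epoch:04d} yields one checkpoint per epoch; save_weights_only avoids
# serializing the full model graph.
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    ckpt_full_path, save_weights_only=True)
```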
@@ -426,8 +440,9 @@ class TransformerTask(object):
     if params["dtype"] == tf.float16:
       opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
-          opt, loss_scale=flags_core.get_loss_scale(self.flags_obj,
-                                                    default_for_fp16="dynamic"))
+          opt,
+          loss_scale=flags_core.get_loss_scale(
+              self.flags_obj, default_for_fp16="dynamic"))
     if self.flags_obj.fp16_implementation == "graph_rewrite":
       # Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
       # determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
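For reference, wrapping an optimizer in the era's experimental `LossScaleOptimizer` with dynamic loss scaling looks roughly like this (the base optimizer here is illustrative):

```python
import tensorflow as tf

base_opt = tf.keras.optimizers.Adam()
# Dynamic loss scaling scales the loss up before backprop and unscales the
# gradients afterwards to reduce float16 underflow.
opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
    base_opt, loss_scale="dynamic")
```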