Commit b1d9ac5b authored by Hongkun Yu, committed by A. Unique TensorFlower

Consolidation & readability: fold TPU strategy creation into distribution_utils, infer TPU use from the active distribution strategy instead of a --use_tpu flag, and report custom-training-loop loss through a Keras metric.

PiperOrigin-RevId: 263863438
parent 3a9e9edb
@@ -181,10 +181,6 @@ def define_transformer_flags():
       default=False,
       help=flags_core.help_wrap(
           'Whether the model runs with custom training loop.'))
-  flags.DEFINE_bool(
-      name='use_tpu',
-      default=False,
-      help=flags_core.help_wrap('Whether the model runs on TPU.'))
   flags.DEFINE_bool(
       name='is_tpu_pod',
       default=False,
...
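Note: with the standalone --use_tpu flag removed, TPU execution is inferred from the distribution strategy object itself, via the new TransformerTask.use_tpu property introduced later in this diff. A minimal sketch of that detection pattern:

    import tensorflow as tf

    def is_tpu_strategy(strategy):
      # TPU use is inferred from the strategy instance rather than a flag.
      return isinstance(strategy, tf.distribute.experimental.TPUStrategy)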
@@ -30,6 +30,8 @@ from absl import flags
 from absl import logging
 import tensorflow as tf

+from tensorflow.python.util import object_identity
+
 # pylint: disable=g-bad-import-order
 from official.transformer import compute_bleu
 from official.transformer.utils import tokenizer
@@ -103,13 +105,10 @@ class TransformerTask(object):
     params["num_gpus"] = num_gpus
     params["use_ctl"] = flags_obj.use_ctl
-    params["use_tpu"] = flags_obj.use_tpu
     params["is_tpu_pod"] = flags_obj.is_tpu_pod
     params["data_dir"] = flags_obj.data_dir
     params["model_dir"] = flags_obj.model_dir
     params["static_batch"] = flags_obj.static_batch
-    if params["use_tpu"] and not params["static_batch"]:
-      raise ValueError("TPU requires static batch for input data.")
     params["max_length"] = flags_obj.max_length
     params["num_parallel_calls"] = (
         flags_obj.num_parallel_calls or tf.data.experimental.AUTOTUNE)
@@ -129,19 +128,14 @@ class TransformerTask(object):
                                     "infer_float32_vars")
       tf.keras.mixed_precision.experimental.set_policy(policy)

-    # TODO(b/139414143): Move the following logic to distribution utils when
-    # remote_eager related code can pass the copybara and being added
-    # completely.
-    if params["use_tpu"]:
-      cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
-          flags_obj.tpu or "local")
-      tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
-      self.distribution_strategy = tf.distribute.experimental.TPUStrategy(
-          cluster_resolver)
+    self.distribution_strategy = distribution_utils.get_distribution_strategy(
+        distribution_strategy=flags_obj.distribution_strategy,
+        num_gpus=num_gpus,
+        tpu_address=flags_obj.tpu or "")
+    if self.use_tpu:
+      if not params["static_batch"]:
+        raise ValueError("TPU requires static batch for input data.")
     else:
-      self.distribution_strategy = distribution_utils.get_distribution_strategy(
-          distribution_strategy=flags_obj.distribution_strategy,
-          num_gpus=num_gpus)
       print("Running transformer with num_gpus =", num_gpus)

     if self.distribution_strategy:
@@ -150,29 +144,35 @@ class TransformerTask(object):
     else:
       print("Not using any distribution strategy.")

+  @property
+  def use_tpu(self):
+    if self.distribution_strategy:
+      return isinstance(self.distribution_strategy,
+                        tf.distribute.experimental.TPUStrategy)
+    return False
+
   def train(self):
     """Trains the model."""
-    params, flags_obj, is_train = self.params, self.flags_obj, True
+    params = self.params
+    flags_obj = self.flags_obj
     # Sets config options.
     keras_utils.set_session_config(
         enable_xla=flags_obj.enable_xla)

     _ensure_dir(flags_obj.model_dir)
-    if self.distribution_strategy:
-      with self.distribution_strategy.scope():
-        model = transformer.create_model(params, is_train)
-        opt = self._create_optimizer()
-        if not params["use_ctl"]:
-          model.compile(opt)
-    else:
-      model = transformer.create_model(params, is_train)
+    with distribution_utils.get_strategy_scope(self.distribution_strategy):
+      model = transformer.create_model(params, is_train=True)
       opt = self._create_optimizer()
-      model.compile(opt)
+      if params["use_ctl"]:
+        train_loss_metric = tf.keras.metrics.Mean(
+            "training_loss", dtype=tf.float32)
+      else:
+        model.compile(opt)
     model.summary()

     train_ds = data_pipeline.train_input_fn(params)
-    if params["use_tpu"]:
+    if self.use_tpu:
       if params["is_tpu_pod"]:
         train_ds = (
             self.distribution_strategy
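Note on the hunk above: the if/else around model construction collapses because distribution_utils.get_strategy_scope is expected to yield the strategy's scope() when a strategy is present and a do-nothing context otherwise. A sketch of that contract (not the module's exact code):

    import contextlib

    def get_strategy_scope(strategy):
      # Returns a context manager: the strategy's scope if one is set,
      # otherwise a no-op context, so callers need no branching.
      if strategy:
        return strategy.scope()
      return contextlib.nullcontext()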
@@ -214,23 +214,20 @@ class TransformerTask(object):
           # of the replicas for backprop.
           scaled_loss = loss / self.distribution_strategy.num_replicas_in_sync

-        trainable_vars = model.trainable_variables
-        grads = tape.gradient(scaled_loss, trainable_vars)
-        opt.apply_gradients(zip(grads, trainable_vars))
-        return scaled_loss
+        # De-dupes variables due to keras tracking issues.
+        tvars = list(
+            object_identity.ObjectIdentitySet(model.trainable_variables))
+        grads = tape.gradient(scaled_loss, tvars)
+        opt.apply_gradients(zip(grads, tvars))
+        # For reporting, the metric takes the mean of losses.
+        train_loss_metric.update_state(loss)

-      loss = tf.constant(0.0)
       for _ in tf.range(steps):
-        per_replica_losses = self.distribution_strategy.experimental_run_v2(
+        train_loss_metric.reset_states()
+        self.distribution_strategy.experimental_run_v2(
             _step_fn, args=(next(iterator),))
-        # Recover the mean of the loss across replicas. E.g.,
-        # loss_0 / 2 + loss_1 / 2 = (loss_0 + loss_1) / 2 for two replicas.
-        loss = self.distribution_strategy.reduce(
-            tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
-      return loss

-    if params["use_tpu"]:
+    if self.use_tpu:
       checkpoint = tf.train.Checkpoint(model=model, optimizer=opt)
       latest_checkpoint = tf.train.latest_checkpoint(flags_obj.model_dir)
       if latest_checkpoint:
@@ -247,23 +244,26 @@ class TransformerTask(object):
       print("Start train iteration:{}/{}".format(i, iterations))
       history = None
       if params["use_ctl"]:
-        if not params["use_tpu"]:
+        if not self.use_tpu:
           raise NotImplementedError(
               "Custom training loop on GPUs is not implemented.")
         train_steps_per_eval = tf.convert_to_tensor(
             flags_obj.steps_between_evals, dtype=tf.int32)
-        train_loss = train_steps(iter(train_ds),
-                                 train_steps_per_eval).numpy().astype(float)
+
+        # Runs training steps.
+        train_steps(iter(train_ds), train_steps_per_eval)
+        train_loss = train_loss_metric.result().numpy().astype(float)
         logging.info("Train Step: %d/%d / loss = %s",
-                     i * flags_obj.steps_between_evals,
-                     flags_obj.train_steps, train_loss)
+                     i * flags_obj.steps_between_evals, flags_obj.train_steps,
+                     train_loss)
         checkpoint_name = checkpoint.save(
             os.path.join(
                 flags_obj.model_dir,
                 "ctl_step_{}.ckpt".format(i * flags_obj.steps_between_evals)))
         logging.info("Saved checkpoint to %s", checkpoint_name)
       else:
-        if params["use_tpu"]:
+        if self.use_tpu:
           raise NotImplementedError(
               "Keras model.fit on TPUs is not implemented.")
         history = model.fit(
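Note on the custom-loop changes in the two hunks above: per-step losses now flow through a tf.keras.metrics.Mean instead of an explicit cross-replica SUM reduction, and trainable variables are de-duplicated with ObjectIdentitySet before gradients are applied. A minimal, self-contained sketch of the metric pattern (compute_loss and the iterator are placeholders, not names from this file):

    import tensorflow as tf

    strategy = tf.distribute.get_strategy()
    with strategy.scope():
      train_loss_metric = tf.keras.metrics.Mean("training_loss",
                                                dtype=tf.float32)

    @tf.function
    def train_steps(iterator, steps):
      def _step_fn(inputs):
        loss = compute_loss(inputs)  # placeholder for the model's loss
        # The metric averages the per-replica losses for the current step,
        # so no explicit tf.distribute.ReduceOp.SUM reduction is needed.
        train_loss_metric.update_state(loss)
      for _ in tf.range(steps):
        train_loss_metric.reset_states()
        strategy.experimental_run_v2(_step_fn, args=(next(iterator),))

    # The host then reads a single scalar for logging:
    # train_loss = train_loss_metric.result().numpy()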
@@ -312,10 +312,11 @@ class TransformerTask(object):
   def predict(self):
     """Predicts result from the model."""
-    params, flags_obj, is_train = self.params, self.flags_obj, False
+    params = self.params
+    flags_obj = self.flags_obj

     with tf.name_scope("model"):
-      model = transformer.create_model(params, is_train)
+      model = transformer.create_model(params, is_train=False)
       self._load_weights_if_possible(
           model, tf.train.latest_checkpoint(self.flags_obj.model_dir))
       model.summary()
@@ -348,7 +349,7 @@ class TransformerTask(object):
       logging.info("Load weights: {}".format(init_weight_path))
       # TODO(b/139414977): Having the same variable restoring method for both
       # TPU and GPU.
-      if self.flags_obj.use_tpu:
+      if self.use_tpu:
         checkpoint = tf.train.Checkpoint(
             model=model, optimizer=self._create_optimizer())
         checkpoint.restore(init_weight_path)
@@ -366,7 +367,7 @@ class TransformerTask(object):
         params["learning_rate"], params["hidden_size"],
         params["learning_rate_warmup_steps"])
     opt = tf.keras.optimizers.Adam(
-        lr_schedule if params["use_tpu"] else params["learning_rate"],
+        lr_schedule if self.use_tpu else params["learning_rate"],
         params["optimizer_adam_beta1"],
         params["optimizer_adam_beta2"],
         epsilon=params["optimizer_adam_epsilon"])
@@ -398,7 +399,7 @@ def main(_):
   else:
     raise ValueError("Invalid mode {}".format(flags_obj.mode))

-  if not flags_obj.use_tpu:
+  if flags_obj.distribution_strategy != "tpu":
     _run_task(task)
   else:
     primary_cpu_task = ("/job:worker"
...
@@ -127,10 +127,7 @@ def get_distribution_strategy(distribution_strategy="default",
     return None

   if distribution_strategy == "tpu":
-    if not tpu_address:
-      raise ValueError("`tpu_address` must be specified when using "
-                       "TPUStrategy.")
+    # When tpu_address is an empty string, we communicate with local TPUs.
     # Initialize TPU System.
     cluster_resolver = tpu_lib.tpu_initialize(tpu_address)
     return tf.distribute.experimental.TPUStrategy(cluster_resolver)
...
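Note: get_distribution_strategy no longer rejects an empty tpu_address; an empty string now means "connect to the local TPU system", matching the flags_obj.tpu or "" call site above. A hedged usage sketch (assumes the official-models distribution_utils module shown in this diff):

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy="tpu",
        tpu_address="")  # "" targets a locally attached TPU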