Commit 47e783e6 authored by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 263485329
parent dd376f53
...@@ -273,7 +273,7 @@ def _generate_synthetic_data(params): ...@@ -273,7 +273,7 @@ def _generate_synthetic_data(params):
label_value=1, label_value=1,
label_dtype=tf.int64, label_dtype=tf.int64,
) )
return dataset.batch(batch) return dataset.batch(batch, drop_remainder=True)
def train_input_fn(params): def train_input_fn(params):
......
...@@ -176,6 +176,25 @@ def define_transformer_flags(): ...@@ -176,6 +176,25 @@ def define_transformer_flags():
flags.DEFINE_string( flags.DEFINE_string(
name='mode', default='train', name='mode', default='train',
help=flags_core.help_wrap('mode: train, eval, or predict')) help=flags_core.help_wrap('mode: train, eval, or predict'))
flags.DEFINE_bool(
name='use_ctl',
default=False,
help=flags_core.help_wrap(
'Whether the model runs with custom training loop.'))
flags.DEFINE_bool(
name='use_tpu',
default=False,
help=flags_core.help_wrap('Whether the model runs on TPU.'))
flags.DEFINE_bool(
name='is_tpu_pod',
default=False,
help=flags_core.help_wrap('Whether the model runs on a TPU pod.'))
flags.DEFINE_bool(
name='use_tpu_2vm_config',
default=False,
help=flags_core.help_wrap(
'Whether the model runs in 2VM mode, Headless server and unit test '
'all use 1VM config.'))
flags_core.set_defaults(data_dir='/tmp/translate_ende', flags_core.set_defaults(data_dir='/tmp/translate_ende',
model_dir='/tmp/transformer_model', model_dir='/tmp/transformer_model',
...@@ -216,8 +235,6 @@ def define_transformer_flags(): ...@@ -216,8 +235,6 @@ def define_transformer_flags():
return True return True
# pylint: enable=unused-variable # pylint: enable=unused-variable
flags_core.require_cloud_storage(['data_dir', 'model_dir', 'export_dir'])
def get_callbacks(): def get_callbacks():
"""Returns common callbacks.""" """Returns common callbacks."""
......
...@@ -23,6 +23,51 @@ import tensorflow as tf ...@@ -23,6 +23,51 @@ import tensorflow as tf
K = tf.keras.backend K = tf.keras.backend
class LearningRateSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
  """Learning rate schedule with linear warmup followed by rsqrt decay."""

  def __init__(self, initial_learning_rate, hidden_size, warmup_steps):
    """Initialize configuration of the learning rate schedule.

    Args:
      initial_learning_rate: A float, the initial learning rate.
      hidden_size: An integer, the model dimension in the hidden layers.
      warmup_steps: An integer, the number of steps required for linear warmup.
    """
    super(LearningRateSchedule, self).__init__()
    self.initial_learning_rate = initial_learning_rate
    self.hidden_size = hidden_size
    # Remember the value exactly as the caller passed it so that get_config()
    # can hand it back unchanged instead of returning the float32 tensor below.
    self._warmup_steps_config = warmup_steps
    # Cast once here so __call__ can divide and compare against the float32
    # global step without re-casting on every invocation.
    self.warmup_steps = tf.cast(warmup_steps, tf.float32)

  def __call__(self, global_step):
    """Calculate learning rate with linear warmup and rsqrt decay.

    Args:
      global_step: An integer, the current global step used for learning rate
        calculation.

    Returns:
      A float, the learning rate needs to be used for current global step.
    """
    with tf.name_scope('learning_rate_schedule'):
      global_step = tf.cast(global_step, tf.float32)
      learning_rate = self.initial_learning_rate
      learning_rate *= (self.hidden_size**-0.5)
      # Apply linear warmup: the factor grows from 0 to 1 over warmup_steps
      # and is clamped at 1 afterwards.
      learning_rate *= tf.minimum(1.0, global_step / self.warmup_steps)
      # Apply rsqrt decay: the divisor is held at sqrt(warmup_steps) until the
      # warmup is over, then grows with sqrt(global_step).
      learning_rate /= tf.sqrt(tf.maximum(global_step, self.warmup_steps))
      return learning_rate

  def get_config(self):
    """Get the configuration of the learning rate schedule.

    Returns:
      A dict holding the three constructor arguments as they were passed in,
      keyed by the constructor's parameter names.
    """
    return {
        'initial_learning_rate': self.initial_learning_rate,
        'hidden_size': self.hidden_size,
        'warmup_steps': self._warmup_steps_config,
    }
class LearningRateFn(object): class LearningRateFn(object):
"""Creates learning rate function.""" """Creates learning rate function."""
......
...@@ -27,12 +27,14 @@ import tempfile ...@@ -27,12 +27,14 @@ import tempfile
from absl import app as absl_app # pylint: disable=unused-import from absl import app as absl_app # pylint: disable=unused-import
from absl import flags from absl import flags
from absl import logging
import tensorflow as tf import tensorflow as tf
# pylint: disable=g-bad-import-order # pylint: disable=g-bad-import-order
from official.transformer import compute_bleu from official.transformer import compute_bleu
from official.transformer.utils import tokenizer from official.transformer.utils import tokenizer
from official.transformer.v2 import data_pipeline from official.transformer.v2 import data_pipeline
from official.transformer.v2 import metrics
from official.transformer.v2 import misc from official.transformer.v2 import misc
from official.transformer.v2 import optimizer from official.transformer.v2 import optimizer
from official.transformer.v2 import transformer from official.transformer.v2 import transformer
...@@ -75,8 +77,8 @@ def evaluate_and_log_bleu(model, bleu_source, bleu_ref, vocab_file): ...@@ -75,8 +77,8 @@ def evaluate_and_log_bleu(model, bleu_source, bleu_ref, vocab_file):
uncased_score, cased_score = translate_and_compute_bleu( uncased_score, cased_score = translate_and_compute_bleu(
model, subtokenizer, bleu_source, bleu_ref) model, subtokenizer, bleu_source, bleu_ref)
tf.compat.v1.logging.info("Bleu score (uncased): %s", uncased_score) logging.info("Bleu score (uncased): %s", uncased_score)
tf.compat.v1.logging.info("Bleu score (cased): %s", cased_score) logging.info("Bleu score (cased): %s", cased_score)
return uncased_score, cased_score return uncased_score, cased_score
...@@ -88,29 +90,26 @@ class TransformerTask(object): ...@@ -88,29 +90,26 @@ class TransformerTask(object):
Args: Args:
flags_obj: Object containing parsed flag values, i.e., FLAGS. flags_obj: Object containing parsed flag values, i.e., FLAGS.
Raises:
ValueError: if not using static batch for input data on TPU.
""" """
self.flags_obj = flags_obj self.flags_obj = flags_obj
self.predict_model = None self.predict_model = None
# Add flag-defined parameters to params object # Add flag-defined parameters to params object
num_gpus = flags_core.get_num_gpus(flags_obj) num_gpus = flags_core.get_num_gpus(flags_obj)
self.distribution_strategy = distribution_utils.get_distribution_strategy(
distribution_strategy=flags_obj.distribution_strategy,
num_gpus=flags_core.get_num_gpus(flags_obj))
print("Running transformer with num_gpus =", num_gpus)
if self.distribution_strategy:
print("For training, using distribution strategy: ",
self.distribution_strategy)
else:
print("Not using any distribution strategy.")
self.params = params = misc.get_model_params(flags_obj.param_set, num_gpus) self.params = params = misc.get_model_params(flags_obj.param_set, num_gpus)
params["num_gpus"] = num_gpus params["num_gpus"] = num_gpus
params["use_ctl"] = flags_obj.use_ctl
params["use_tpu"] = flags_obj.use_tpu
params["is_tpu_pod"] = flags_obj.is_tpu_pod
params["data_dir"] = flags_obj.data_dir params["data_dir"] = flags_obj.data_dir
params["model_dir"] = flags_obj.model_dir params["model_dir"] = flags_obj.model_dir
params["static_batch"] = flags_obj.static_batch params["static_batch"] = flags_obj.static_batch
if params["use_tpu"] and not params["static_batch"]:
raise ValueError("TPU requires static batch for input data.")
params["max_length"] = flags_obj.max_length params["max_length"] = flags_obj.max_length
params["num_parallel_calls"] = ( params["num_parallel_calls"] = (
flags_obj.num_parallel_calls or tf.data.experimental.AUTOTUNE) flags_obj.num_parallel_calls or tf.data.experimental.AUTOTUNE)
...@@ -130,6 +129,27 @@ class TransformerTask(object): ...@@ -130,6 +129,27 @@ class TransformerTask(object):
"infer_float32_vars") "infer_float32_vars")
tf.keras.mixed_precision.experimental.set_policy(policy) tf.keras.mixed_precision.experimental.set_policy(policy)
# TODO(b/139414143): Move the following logic to distribution utils when
# remote_eager related code can pass the copybara and being added
# completely.
if params["use_tpu"]:
cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
flags_obj.tpu or "local")
tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
self.distribution_strategy = tf.distribute.experimental.TPUStrategy(
cluster_resolver)
else:
self.distribution_strategy = distribution_utils.get_distribution_strategy(
distribution_strategy=flags_obj.distribution_strategy,
num_gpus=num_gpus)
print("Running transformer with num_gpus =", num_gpus)
if self.distribution_strategy:
print("For training, using distribution strategy: ",
self.distribution_strategy)
else:
print("Not using any distribution strategy.")
def train(self): def train(self):
"""Trains the model.""" """Trains the model."""
params, flags_obj, is_train = self.params, self.flags_obj, True params, flags_obj, is_train = self.params, self.flags_obj, True
...@@ -142,7 +162,8 @@ class TransformerTask(object): ...@@ -142,7 +162,8 @@ class TransformerTask(object):
with self.distribution_strategy.scope(): with self.distribution_strategy.scope():
model = transformer.create_model(params, is_train) model = transformer.create_model(params, is_train)
opt = self._create_optimizer() opt = self._create_optimizer()
model.compile(opt) if not params["use_ctl"]:
model.compile(opt)
else: else:
model = transformer.create_model(params, is_train) model = transformer.create_model(params, is_train)
opt = self._create_optimizer() opt = self._create_optimizer()
...@@ -151,12 +172,71 @@ class TransformerTask(object): ...@@ -151,12 +172,71 @@ class TransformerTask(object):
model.summary() model.summary()
train_ds = data_pipeline.train_input_fn(params) train_ds = data_pipeline.train_input_fn(params)
map_data_fn = data_pipeline.map_data_for_transformer_fn if params["use_tpu"]:
train_ds = train_ds.map(map_data_fn, if params["is_tpu_pod"]:
num_parallel_calls=params["num_parallel_calls"]) train_ds = (
self.distribution_strategy
.experimental_distribute_datasets_from_function(
lambda: data_pipeline.train_input_fn(params)))
else:
train_ds = (
self.distribution_strategy.experimental_distribute_dataset(train_ds)
)
else:
map_data_fn = data_pipeline.map_data_for_transformer_fn
train_ds = train_ds.map(
map_data_fn, num_parallel_calls=params["num_parallel_calls"])
callbacks = self._create_callbacks(flags_obj.model_dir, 0, params) callbacks = self._create_callbacks(flags_obj.model_dir, 0, params)
# TODO(b/139418525): Refactor the custom training loop logic.
@tf.function
def train_steps(iterator, steps):
  """Runs `steps` training steps under the distribution strategy.

  Args:
    iterator: The input iterator of the training dataset.
    steps: An integer, the number of training steps.

  Returns:
    The cross-replica SUM of the scaled loss from the last executed step, or
    the constant 0.0 when the loop body never runs.
  """
  strategy = self.distribution_strategy

  def _replica_step(batch):
    """Per-replica step: forward pass, scaled loss, gradient update."""
    features, labels = batch
    with tf.GradientTape() as tape:
      logits = model([features, labels], training=True)
      unscaled_loss = metrics.transformer_loss(
          logits, labels, params["label_smoothing"], params["vocab_size"])
      # Dividing by the replica count here means the SUM reduction done by
      # the caller yields the average loss across replicas for backprop.
      replica_loss = unscaled_loss / strategy.num_replicas_in_sync

    variables = model.trainable_variables
    gradients = tape.gradient(replica_loss, variables)
    opt.apply_gradients(zip(gradients, variables))
    return replica_loss

  loss = tf.constant(0.0)
  for _ in tf.range(steps):
    replica_losses = strategy.experimental_run_v2(
        _replica_step, args=(next(iterator),))
    # Summing the already-scaled per-replica values recovers the mean, e.g.
    # loss_0 / 2 + loss_1 / 2 = (loss_0 + loss_1) / 2 for two replicas.
    loss = strategy.reduce(
        tf.distribute.ReduceOp.SUM, replica_losses, axis=None)
  return loss
if params["use_tpu"]:
checkpoint = tf.train.Checkpoint(model=model, optimizer=opt)
latest_checkpoint = tf.train.latest_checkpoint(flags_obj.model_dir)
if latest_checkpoint:
checkpoint.restore(latest_checkpoint)
logging.info("Loaded checkpoint %s", latest_checkpoint)
if flags_obj.train_steps < flags_obj.steps_between_evals: if flags_obj.train_steps < flags_obj.steps_between_evals:
flags_obj.steps_between_evals = flags_obj.train_steps flags_obj.steps_between_evals = flags_obj.train_steps
iterations = flags_obj.train_steps // flags_obj.steps_between_evals iterations = flags_obj.train_steps // flags_obj.steps_between_evals
...@@ -165,28 +245,51 @@ class TransformerTask(object): ...@@ -165,28 +245,51 @@ class TransformerTask(object):
cased_score_history, uncased_score_history = [], [] cased_score_history, uncased_score_history = [], []
for i in range(1, iterations + 1): for i in range(1, iterations + 1):
print("Start train iteration:{}/{}".format(i, iterations)) print("Start train iteration:{}/{}".format(i, iterations))
history = model.fit( history = None
train_ds, if params["use_ctl"]:
initial_epoch=i-1, if not params["use_tpu"]:
epochs=i, raise NotImplementedError(
steps_per_epoch=flags_obj.steps_between_evals, "Custom training loop on GPUs is not implemented.")
callbacks=callbacks, train_steps_per_eval = tf.convert_to_tensor(
# If TimeHistory is enabled, progress bar would be messy. Increase the flags_obj.steps_between_evals, dtype=tf.int32)
# verbose level to get rid of it. train_loss = train_steps(iter(train_ds),
verbose=(2 if flags_obj.enable_time_history else 1)) train_steps_per_eval).numpy().astype(float)
logging.info("Train Step: %d/%d / loss = %s",
i * flags_obj.steps_between_evals,
flags_obj.train_steps, train_loss)
checkpoint_name = checkpoint.save(
os.path.join(
flags_obj.model_dir,
"ctl_step_{}.ckpt".format(i * flags_obj.steps_between_evals)))
logging.info("Saved checkpoint to %s", checkpoint_name)
else:
if params["use_tpu"]:
raise NotImplementedError(
"Keras model.fit on TPUs is not implemented.")
history = model.fit(
train_ds,
initial_epoch=i - 1,
epochs=i,
steps_per_epoch=flags_obj.steps_between_evals,
callbacks=callbacks,
# If TimeHistory is enabled, progress bar would be messy. Increase
# the verbose level to get rid of it.
verbose=(2 if flags_obj.enable_time_history else 1))
logging.info("Train history: {}".format(history.history))
print("End train iteration:{}/{} global step:{}".format( print("End train iteration:{}/{} global step:{}".format(
i, i,
iterations, iterations,
i*flags_obj.steps_between_evals)) i*flags_obj.steps_between_evals))
tf.compat.v1.logging.info("Train history: {}".format(history.history))
stats = misc.build_stats(history, callbacks)
if (flags_obj.bleu_source and flags_obj.bleu_ref): if (flags_obj.bleu_source and flags_obj.bleu_ref):
uncased_score, cased_score = self.eval() uncased_score, cased_score = self.eval()
cased_score_history.append([i, cased_score]) cased_score_history.append([i, cased_score])
uncased_score_history.append([i, uncased_score]) uncased_score_history.append([i, uncased_score])
stats = misc.build_stats(history, callbacks) stats = ({
"loss": train_loss
} if history is None else misc.build_stats(history, callbacks))
if uncased_score and cased_score: if uncased_score and cased_score:
stats["bleu_uncased"] = uncased_score stats["bleu_uncased"] = uncased_score
stats["bleu_cased"] = cased_score stats["bleu_cased"] = cased_score
...@@ -242,16 +345,28 @@ class TransformerTask(object): ...@@ -242,16 +345,28 @@ class TransformerTask(object):
def _load_weights_if_possible(self, model, init_weight_path=None):
  """Restores weights into `model` when a path is given.

  Args:
    model: The model object to load weights into.
    init_weight_path: Path to load from. When falsy, nothing is loaded and a
      message is printed instead.
  """
  if not init_weight_path:
    print("Weights not loaded from path:{}".format(init_weight_path))
    return

  logging.info("Load weights: {}".format(init_weight_path))
  # TODO(b/139414977): Having the same variable restoring method for both
  # TPU and GPU.
  if not self.flags_obj.use_tpu:
    model.load_weights(init_weight_path)
    return

  # TPU runs restore through a tf.train.Checkpoint that pairs the model with
  # a freshly created optimizer.
  restore_target = tf.train.Checkpoint(
      model=model, optimizer=self._create_optimizer())
  restore_target.restore(init_weight_path)
def _create_optimizer(self): def _create_optimizer(self):
"""Creates optimizer.""" """Creates optimizer."""
params = self.params params = self.params
# TODO(b/139414679): Explore the difference between using
# LearningRateSchedule and callback for GPU runs, and try to merge them.
lr_schedule = optimizer.LearningRateSchedule(
params["learning_rate"], params["hidden_size"],
params["learning_rate_warmup_steps"])
opt = tf.keras.optimizers.Adam( opt = tf.keras.optimizers.Adam(
params["learning_rate"], lr_schedule if params["use_tpu"] else params["learning_rate"],
params["optimizer_adam_beta1"], params["optimizer_adam_beta1"],
params["optimizer_adam_beta2"], params["optimizer_adam_beta2"],
epsilon=params["optimizer_adam_epsilon"]) epsilon=params["optimizer_adam_epsilon"])
...@@ -264,25 +379,35 @@ class TransformerTask(object): ...@@ -264,25 +379,35 @@ class TransformerTask(object):
def _ensure_dir(log_dir):
  """Creates `log_dir` through tf.io.gfile unless the path already exists."""
  if tf.io.gfile.exists(log_dir):
    return
  tf.io.gfile.makedirs(log_dir)
def main(_):
  """Builds a TransformerTask from the flags and runs the selected mode.

  Args:
    _: Unused positional argument.

  Raises:
    ValueError: if the `mode` flag is not "train", "predict" or "eval".
  """
  flags_obj = flags.FLAGS
  with logger.benchmark_context(flags_obj):
    task = TransformerTask(flags_obj)

    def _run_task(task):
      """Dispatches to the task method named by the `mode` flag."""
      if flags_obj.mode == "train":
        task.train()
      elif flags_obj.mode == "predict":
        task.predict()
      elif flags_obj.mode == "eval":
        task.eval()
      else:
        raise ValueError("Invalid mode {}".format(flags_obj.mode))

    if not flags_obj.use_tpu:
      _run_task(task)
    else:
      # `use_tpu_2vm_config` is a boolean flag defaulting to False, so it is
      # never None; test its truth value so that only the 2VM configuration
      # pins the task to the worker job and 1VM runs use no device constraint.
      primary_cpu_task = "/job:worker" if flags_obj.use_tpu_2vm_config else ""
      with tf.device(primary_cpu_task):
        _run_task(task)
if __name__ == "__main__": if __name__ == "__main__":
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO) logging.set_verbosity(logging.INFO)
misc.define_transformer_flags() misc.define_transformer_flags()
absl_app.run(main) absl_app.run(main)
...@@ -89,6 +89,10 @@ class TransformerTaskTest(tf.test.TestCase): ...@@ -89,6 +89,10 @@ class TransformerTaskTest(tf.test.TestCase):
if context.num_gpus() >= 2: if context.num_gpus() >= 2:
self.skipTest('No need to test 2+ GPUs without a distribution strategy.') self.skipTest('No need to test 2+ GPUs without a distribution strategy.')
FLAGS.distribution_strategy = 'one_device' FLAGS.distribution_strategy = 'one_device'
if tf.test.is_built_with_cuda():
FLAGS.num_gpus = 1
else:
FLAGS.num_gpus = 0
FLAGS.static_batch = True FLAGS.static_batch = True
t = tm.TransformerTask(FLAGS) t = tm.TransformerTask(FLAGS)
t.train() t.train()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment