Commit 172bf8ff authored by Zongwei Zhou, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 329042049
parent eee5ca5f
@@ -82,6 +82,13 @@ def define_common_bert_flags():
                     'allreduce in optimizer.apply_gradients(). If fp16 mixed '
                     'precision training is used, this also enables allreduce '
                     'gradients in fp16.')
+  flags.DEFINE_integer('allreduce_bytes_per_pack', 0,
+                       'Number of bytes of a gradient pack for allreduce. '
+                       'Should be a positive integer; if set to 0, all '
+                       'gradients are in one pack. Breaking gradients into '
+                       'packs can enable overlap between allreduce and '
+                       'backprop computation. This flag only takes effect '
+                       'when explicit_allreduce is set to True.')
   flags_core.define_log_steps()
...
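With the flag defined, gradient packing can be switched on from the command line together with explicit allreduce. A hedged example invocation (the script name and the 32 MiB value are illustrative, and the required data and model flags are omitted):

python3 run_pretraining.py --explicit_allreduce=true --allreduce_bytes_per_pack=33554432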
@@ -133,7 +133,8 @@ def run_customized_training_loop(
     explicit_allreduce=False,
     pre_allreduce_callbacks=None,
     post_allreduce_callbacks=None,
-    train_summary_interval=0):
+    train_summary_interval=0,
+    allreduce_bytes_per_pack=0):
   """Run BERT pretrain model training using low-level API.

   Arguments:
@@ -201,6 +202,11 @@ def run_customized_training_loop(
         when explicit_allreduce=True.
       train_summary_interval: Step interval for training summaries. If the value
         is a negative number, then training summaries are not enabled.
+      allreduce_bytes_per_pack: A non-negative integer. Breaks collective
+        operations into packs of a certain size. If it's zero, all gradients
+        are in one pack. Breaking gradients into packs can enable overlap
+        between allreduce and backprop computation. This argument only takes
+        effect when explicit_allreduce is set to True.

   Returns:
       Trained model.
@@ -332,7 +338,8 @@ def run_customized_training_loop(
         grad_utils.minimize_using_explicit_allreduce(tape, optimizer, loss,
                                                      training_vars,
                                                      pre_allreduce_callbacks,
-                                                     post_allreduce_callbacks)
+                                                     post_allreduce_callbacks,
+                                                     allreduce_bytes_per_pack)
       else:
         if isinstance(optimizer,
                       tf.keras.mixed_precision.experimental.LossScaleOptimizer):
...
@@ -109,7 +109,8 @@ def run_customized_training(strategy,
                             custom_callbacks=None,
                             explicit_allreduce=False,
                             pre_allreduce_callbacks=None,
-                            post_allreduce_callbacks=None):
+                            post_allreduce_callbacks=None,
+                            allreduce_bytes_per_pack=0):
   """Run BERT pretrain model training using low-level API."""

   train_input_fn = get_pretrain_dataset_fn(input_files, max_seq_length,
@@ -146,6 +147,7 @@ def run_customized_training(strategy,
       explicit_allreduce=explicit_allreduce,
       pre_allreduce_callbacks=pre_allreduce_callbacks,
       post_allreduce_callbacks=post_allreduce_callbacks,
+      allreduce_bytes_per_pack=allreduce_bytes_per_pack,
       train_summary_interval=train_summary_interval,
       custom_callbacks=custom_callbacks)
@@ -165,10 +167,12 @@ def run_bert_pretrain(strategy, custom_callbacks=None):
   performance.set_mixed_precision_policy(common_flags.dtype())

-  # If explicit_allreduce = True, apply_gradients() no longer implicitly
-  # allreduce gradients, users manually allreduce gradient and pass the
-  # allreduced grads_and_vars to apply_gradients(). clip_by_global_norm is kept
-  # before allreduce, to be consistent with original TF1 model.
+  # post_allreduce_callbacks and allreduce_bytes_per_pack only take effect
+  # when explicit_allreduce = True. optimizer.apply_gradients() then no longer
+  # implicitly allreduces gradients; users manually allreduce gradients and
+  # pass the allreduced grads_and_vars to apply_gradients().
+  # With explicit_allreduce = True, clip_by_global_norm is moved to after
+  # allreduce.
   return run_customized_training(
       strategy,
       bert_config,
@@ -191,7 +195,8 @@ def run_bert_pretrain(strategy, custom_callbacks=None):
       explicit_allreduce=FLAGS.explicit_allreduce,
       pre_allreduce_callbacks=[
           model_training_utils.clip_by_global_norm_callback
-      ])
+      ],
+      allreduce_bytes_per_pack=FLAGS.allreduce_bytes_per_pack)


 def main(_):
...
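The clip_by_global_norm_callback passed above goes through the callback hooks described in the comment: each callback receives a list of (gradient, variable) pairs and returns a new list of pairs. A minimal stand-in sketch of that shape, not the actual model_training_utils implementation (whose clip norm and return type may differ):

import tensorflow as tf

def global_norm_clip_callback(grads_and_vars):
  # Clip gradients by global norm while preserving the (grad, var) pairing.
  # The clip norm of 1.0 is an illustrative value.
  grads, variables = zip(*grads_and_vars)
  clipped_grads, _ = tf.clip_by_global_norm(list(grads), clip_norm=1.0)
  return list(zip(clipped_grads, variables))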
@@ -260,10 +260,12 @@ def train_squad(strategy,
         use_graph_rewrite=common_flags.use_graph_rewrite())
     return squad_model, core_model

-  # If explicit_allreduce = True, apply_gradients() no longer implicitly
-  # allreduce gradients, users manually allreduce gradient and pass the
-  # allreduced grads_and_vars to apply_gradients(). clip_by_global_norm is kept
-  # before allreduce, to be consistent with the original TF1 model.
+  # post_allreduce_callbacks and allreduce_bytes_per_pack only take effect
+  # when explicit_allreduce = True. optimizer.apply_gradients() then no longer
+  # implicitly allreduces gradients; users manually allreduce gradients and
+  # pass the allreduced grads_and_vars to apply_gradients().
+  # With explicit_allreduce = True, clip_by_global_norm is moved to after
+  # allreduce.
   model_training_utils.run_customized_training_loop(
       strategy=strategy,
       model_fn=_get_squad_model,
@@ -280,7 +282,8 @@ def train_squad(strategy,
       explicit_allreduce=FLAGS.explicit_allreduce,
       pre_allreduce_callbacks=[
           model_training_utils.clip_by_global_norm_callback
-      ])
+      ],
+      allreduce_bytes_per_pack=FLAGS.allreduce_bytes_per_pack)


 def prediction_output_squad(strategy, input_meta_data, tokenizer, squad_lib,
...
@@ -48,7 +48,8 @@ def _filter_grads(grads_and_vars):
 def _filter_and_allreduce_gradients(grads_and_vars,
-                                    allreduce_precision="float32"):
+                                    allreduce_precision="float32",
+                                    bytes_per_pack=0):
   """Filter None grads and then allreduce gradients in specified precision.

   This utils function is used when users intent to explicitly allreduce
@@ -59,6 +60,8 @@ def _filter_and_allreduce_gradients(grads_and_vars,
   Arguments:
     grads_and_vars: gradients and variables pairs.
     allreduce_precision: Whether to allreduce gradients in float32 or float16.
+    bytes_per_pack: A non-negative integer. Breaks collective operations into
+      packs of a certain size. If it's zero, all gradients are in one pack.

   Returns:
     pairs of allreduced non-None gradients and variables.
@@ -67,8 +70,10 @@ def _filter_and_allreduce_gradients(grads_and_vars,
   (grads, variables) = zip(*filtered_grads_and_vars)
   if allreduce_precision == "float16":
     grads = [tf.cast(grad, "float16") for grad in grads]
+  hints = tf.distribute.experimental.CollectiveHints(
+      bytes_per_pack=bytes_per_pack)
   allreduced_grads = tf.distribute.get_replica_context().all_reduce(
-      tf.distribute.ReduceOp.SUM, grads)
+      tf.distribute.ReduceOp.SUM, grads, experimental_hints=hints)
   if allreduce_precision == "float16":
     allreduced_grads = [tf.cast(grad, "float32") for grad in allreduced_grads]
   return allreduced_grads, variables
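The two added lines above are the core of the change: the pack size is forwarded to the collective as a hint. A minimal standalone sketch of the same pattern under the TF 2.3-era API; the MirroredStrategy and the 4 MiB pack size are assumptions for illustration, and later TF releases replace CollectiveHints/experimental_hints with CommunicationOptions/options:

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

@tf.function
def packed_all_reduce(value):
  def step_fn(v):
    # bytes_per_pack=0 would keep all tensors in one pack; a positive value
    # splits the allreduce into packs of roughly that many bytes so early
    # packs can start reducing while backprop continues.
    hints = tf.distribute.experimental.CollectiveHints(
        bytes_per_pack=4 * 1024 * 1024)
    replica_ctx = tf.distribute.get_replica_context()
    return replica_ctx.all_reduce(
        tf.distribute.ReduceOp.SUM, v, experimental_hints=hints)

  return strategy.run(step_fn, args=(value,))

# Example: sums the (identical) value across replicas.
summed = packed_all_reduce(tf.constant([1.0, 2.0, 3.0]))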
@@ -85,7 +90,8 @@ def minimize_using_explicit_allreduce(tape,
                                       loss,
                                       trainable_variables,
                                       pre_allreduce_callbacks=None,
-                                      post_allreduce_callbacks=None):
+                                      post_allreduce_callbacks=None,
+                                      allreduce_bytes_per_pack=0):
   """Minimizes loss for one step by updating `trainable_variables`.

   Minimizes loss for one step by updating `trainable_variables`.
@@ -111,6 +117,9 @@ def minimize_using_explicit_allreduce(tape,
       returns a new gradients and model variables paris. The callback
       functions will be invoked in the list order and right before gradients
       are applied to variables for updates. Default is no callbacks.
+    allreduce_bytes_per_pack: A non-negative integer. Breaks collective
+      operations into packs of a certain size. If it's zero, all gradients
+      are in one pack.
   """
   if isinstance(optimizer,
                 tf.keras.mixed_precision.experimental.LossScaleOptimizer):
@@ -123,7 +132,9 @@ def minimize_using_explicit_allreduce(tape,
       grads_and_vars = _run_callbacks(pre_allreduce_callbacks, grads_and_vars)
     (allreduced_scaled_grads,
      filtered_training_vars) = _filter_and_allreduce_gradients(
-         grads_and_vars, allreduce_precision="float16")
+         grads_and_vars,
+         allreduce_precision="float16",
+         bytes_per_pack=allreduce_bytes_per_pack)
     allreduced_unscaled_grads = optimizer.get_unscaled_gradients(
         allreduced_scaled_grads)
     grads_and_vars = zip(allreduced_unscaled_grads, filtered_training_vars)
@@ -135,7 +146,9 @@ def minimize_using_explicit_allreduce(tape,
       grads_and_vars = _run_callbacks(pre_allreduce_callbacks, grads_and_vars)
     (allreduced_grads,
      filtered_training_vars) = _filter_and_allreduce_gradients(
-         grads_and_vars, allreduce_precision="float32")
+         grads_and_vars,
+         allreduce_precision="float32",
+         bytes_per_pack=allreduce_bytes_per_pack)
     grads_and_vars = zip(allreduced_grads, filtered_training_vars)
     if post_allreduce_callbacks:
       grads_and_vars = _run_callbacks(post_allreduce_callbacks, grads_and_vars)
...
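Taken together, a hedged sketch of how a distributed train step might call the extended helper. The strategy, model, optimizer, loss function, and the 8 MiB pack size are stand-ins; grad_utils refers to the module patched in this hunk and its import path is omitted:

import tensorflow as tf
# Assumption: grad_utils is the module patched above; import path omitted.

strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
  model = tf.keras.Sequential([tf.keras.layers.Dense(2)])
  optimizer = tf.keras.optimizers.SGD(0.01)
loss_fn = tf.keras.losses.MeanSquaredError(
    reduction=tf.keras.losses.Reduction.SUM)

@tf.function
def train_step(features, labels):
  def step_fn(features, labels):
    with tf.GradientTape() as tape:
      loss = loss_fn(labels, model(features, training=True))
    # Explicit allreduce with gradient packing; allreduce_bytes_per_pack=0
    # would keep all gradients in a single pack.
    grad_utils.minimize_using_explicit_allreduce(
        tape, optimizer, loss, model.trainable_variables,
        pre_allreduce_callbacks=None,
        post_allreduce_callbacks=None,
        allreduce_bytes_per_pack=8 * 1024 * 1024)

  strategy.run(step_fn, args=(features, labels))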