"git@developer.sourcefind.cn:OpenDAS/torch-scatter.git" did not exist on "4179782d4571ea3366706fff037fa952d8a95fe5"
Commit 11ccb99e authored by Zongwei Zhou, committed by A. Unique TensorFlower

Temporarily disable explicit allreduce in BERT SQuAD

In BERT SQuAD, disable explicit allreduce for now to keep the original clip_by_global_norm math. With explicit allreduce, the gradients before the allreduce are scaled, so even if clip_by_global_norm is moved before the allreduce (as in TF1 and pre-TF 2.2) it operates on scaled gradients and the math changes. With explicit allreduce it is therefore better to apply clip_by_global_norm after the allreduce.

PiperOrigin-RevId: 299278082
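
The commit message argues that, with loss scaling active (loss divided by num_replicas_in_sync), clipping before the explicit allreduce operates on scaled per-replica gradients and therefore triggers at a different effective threshold than clipping the summed gradients. A minimal numeric sketch of that point, using hypothetical values rather than the model's actual code path:

import tensorflow as tf

num_replicas = 8
true_grad = tf.constant([3.0, 4.0])      # per-replica gradient, global norm 5.0
scaled_grad = true_grad / num_replicas   # what the tape yields when the loss is scaled

# Clipping before allreduce sees the scaled norm (0.625 < 1.0), so nothing is
# clipped and the later cross-replica sum restores the full, unclipped gradient.
pre_clipped, _ = tf.clip_by_global_norm([scaled_grad], clip_norm=1.0)
effective_pre = pre_clipped[0] * num_replicas            # -> [3., 4.]

# Clipping after allreduce sees the summed gradient (norm 5.0) and rescales it.
summed = scaled_grad * num_replicas                      # stand-in for the allreduce sum
post_clipped, _ = tf.clip_by_global_norm([summed], clip_norm=1.0)

print(effective_pre.numpy(), post_clipped[0].numpy())    # [3. 4.] vs [0.6 0.8]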
parent f8777524
@@ -150,7 +150,9 @@ def run_customized_training_loop(
         and model variables pairs as input, manipulate them, and returns a new
         gradients and model variables paris. The callback functions will be
         invoked in the list order and before gradients are allreduced.
-        Default is no callbacks. Only used when explicit_allreduce=True.
+        With mixed precision training, the pre_allreduce_allbacks will be
+        applied on scaled_gradients. Default is no callbacks.
+        Only used when explicit_allreduce=True.
       post_allreduce_callbacks: A list of callback functions that takes
         gradients and model variables pairs as input, manipulate them, and
         returns a new gradients and model variables paris. The callback
@@ -269,11 +269,10 @@ def train_squad(strategy,
           loss_factor=1.0 /
           strategy.num_replicas_in_sync if FLAGS.scale_loss else 1.0)
 
-  # when all_reduce_sum_gradients = False, apply_gradients() no longer
-  # implicitly allreduce gradients, users manually allreduce gradient and
-  # passed the allreduced grads_and_vars. For now, the clip_by_global_norm
-  # will be moved to before users' manual allreduce to keep the math
-  # unchanged.
+  # If explicit_allreduce = True, apply_gradients() no longer implicitly
+  # allreduce gradients, users manually allreduce gradient and pass the
+  # allreduced grads_and_vars to apply_gradients(). clip_by_global_norm will be
+  # applied to allreduced gradients.
   def clip_by_global_norm_callback(grads_and_vars):
     grads, variables = zip(*grads_and_vars)
     (clipped_grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)
@@ -291,8 +290,8 @@ def train_squad(strategy,
       init_checkpoint=FLAGS.init_checkpoint,
       run_eagerly=run_eagerly,
       custom_callbacks=custom_callbacks,
-      explicit_allreduce=True,
-      pre_allreduce_callbacks=[clip_by_global_norm_callback])
+      explicit_allreduce=False,
+      post_allreduce_callbacks=[clip_by_global_norm_callback])
 
 
 def predict_squad(strategy, input_meta_data, tokenizer, bert_config, squad_lib):
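
Pulled out of the diff above for context, a self-contained sketch of clip_by_global_norm_callback applied to dummy gradients. The final return line and the example variables are assumptions (the hunk is truncated before the callback's return), not part of the commit:

import tensorflow as tf

def clip_by_global_norm_callback(grads_and_vars):
  grads, variables = zip(*grads_and_vars)
  (clipped_grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)
  return zip(clipped_grads, variables)  # assumed; the diff cuts off before this line

v1 = tf.Variable([1.0, 2.0], name='dense/kernel')   # placeholder variables
v2 = tf.Variable([3.0], name='dense/bias')
grads_and_vars = [(tf.constant([3.0, 4.0]), v1), (tf.constant([12.0]), v2)]

# Global norm is sqrt(9 + 16 + 144) = 13, so every gradient is scaled by 1/13.
for grad, var in clip_by_global_norm_callback(grads_and_vars):
  print(var.name, grad.numpy())

With the change above, the same callback is now passed via post_allreduce_callbacks instead of pre_allreduce_callbacks.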
@@ -104,7 +104,8 @@ def minimize_using_explicit_allreduce(tape,
         and model variables pairs as input, manipulate them, and returns a new
         gradients and model variables pairs. The callback functions will be
         invoked in the list order and before gradients are allreduced.
-        Default is no callbacks.
+        With mixed precision training, the pre_allreduce_allbacks will be
+        applied on scaled_gradients. Default is no callbacks.
       post_allreduce_callbacks: A list of callback functions that takes
         gradients and model variables pairs as input, manipulate them, and
         returns a new gradients and model variables paris. The callback
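
The hunk above only touches the docstring. As a rough ordering sketch (an assumption, not grad_utils' actual implementation), this is the flow the pre/post callback documentation implies: pre_allreduce_callbacks see the scaled per-replica gradients, the allreduce happens in between, and post_allreduce_callbacks see the summed gradients just before apply_gradients():

import tensorflow as tf

def explicit_allreduce_sketch(tape, optimizer, scaled_loss, variables,
                              pre_allreduce_callbacks=None,
                              post_allreduce_callbacks=None):
  """Hypothetical helper mirroring the documented callback ordering."""
  grads = tape.gradient(scaled_loss, variables)
  grads_and_vars = list(zip(grads, variables))

  # Pre-allreduce callbacks operate on the scaled, per-replica gradients.
  for callback in pre_allreduce_callbacks or []:
    grads_and_vars = list(callback(grads_and_vars))

  # Placeholder for the cross-replica sum; a real implementation would call
  # tf.distribute.get_replica_context().all_reduce(...) on the gradients here.
  allreduced = [(grad, var) for grad, var in grads_and_vars]

  # Post-allreduce callbacks (e.g. clip_by_global_norm_callback) operate on the
  # summed gradients, which is why the clipping math is unchanged in that position.
  for callback in post_allreduce_callbacks or []:
    allreduced = list(callback(allreduced))

  optimizer.apply_gradients(allreduced)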