Commit 184c5586 authored by Hongkun Yu, committed by A. Unique TensorFlower

[Minor Cleanup] Move clip_by_global_norm_callback to model_training_utils

PiperOrigin-RevId: 328888268
parent 9ac54b65
@@ -121,10 +121,3 @@ def use_graph_rewrite():
 def get_loss_scale():
   return flags_core.get_loss_scale(flags.FLAGS, default_for_fp16='dynamic')
-
-
-def clip_by_global_norm_callback(grads_and_vars):
-  grads, variables = zip(*grads_and_vars)
-  (clipped_grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)
-  return zip(clipped_grads, variables)
@@ -75,6 +75,13 @@ def _float_metric_value(metric):
   return metric.result().numpy().astype(float)
+
+
+def clip_by_global_norm_callback(grads_and_vars):
+  """Performs gradient clipping."""
+  grads, variables = zip(*grads_and_vars)
+  (clipped_grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)
+  return zip(clipped_grads, variables)


 def steps_to_run(current_step, steps_per_epoch, steps_per_loop):
   """Calculates steps to run on device."""
   if steps_per_loop <= 0:
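For reference, here is a minimal standalone sketch (not part of this commit) of how the relocated clip_by_global_norm_callback is typically applied to a list of (gradient, variable) pairs before the optimizer step. The toy model, data, and optimizer below are illustrative assumptions, not code from the repository.

import tensorflow as tf


def clip_by_global_norm_callback(grads_and_vars):
  """Clips gradients to a global norm of 1.0, as in model_training_utils."""
  grads, variables = zip(*grads_and_vars)
  (clipped_grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)
  return zip(clipped_grads, variables)


# Illustrative toy model and data (assumptions, not from the commit).
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
x = tf.random.normal([8, 4])
y = tf.random.normal([8, 1])

with tf.GradientTape() as tape:
  loss = tf.reduce_mean(tf.square(model(x) - y))

grads = tape.gradient(loss, model.trainable_variables)
grads_and_vars = list(zip(grads, model.trainable_variables))
# The callback returns new (clipped_grad, variable) pairs that can be passed
# directly to the optimizer.
optimizer.apply_gradients(clip_by_global_norm_callback(grads_and_vars))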
@@ -189,7 +189,9 @@ def run_bert_pretrain(strategy, custom_callbacks=None):
       FLAGS.train_summary_interval,
       custom_callbacks=custom_callbacks,
       explicit_allreduce=FLAGS.explicit_allreduce,
-      pre_allreduce_callbacks=[common_flags.clip_by_global_norm_callback])
+      pre_allreduce_callbacks=[
+          model_training_utils.clip_by_global_norm_callback
+      ])


 def main(_):
@@ -278,7 +278,9 @@ def train_squad(strategy,
       run_eagerly=run_eagerly,
       custom_callbacks=custom_callbacks,
       explicit_allreduce=FLAGS.explicit_allreduce,
-      pre_allreduce_callbacks=[common_flags.clip_by_global_norm_callback])
+      pre_allreduce_callbacks=[
+          model_training_utils.clip_by_global_norm_callback
+      ])


 def prediction_output_squad(strategy, input_meta_data, tokenizer, squad_lib,
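Both call sites above now pass the relocated helper through pre_allreduce_callbacks. As a hedged illustration (not the library's actual implementation), such callbacks can be thought of as a chain of transforms applied to the (gradient, variable) pairs before gradients are aggregated across replicas; apply_pre_allreduce_callbacks below is a hypothetical helper name introduced only for this sketch.

import tensorflow as tf


def apply_pre_allreduce_callbacks(grads_and_vars, callbacks):
  """Applies each callback in turn; each takes and returns (grad, var) pairs."""
  for callback in callbacks:
    grads_and_vars = callback(grads_and_vars)
  return list(grads_and_vars)


# Re-declared here so the snippet is self-contained; matches the function
# added to model_training_utils in this commit.
def clip_by_global_norm_callback(grads_and_vars):
  grads, variables = zip(*grads_and_vars)
  (clipped_grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)
  return zip(clipped_grads, variables)


grads_and_vars = [(tf.constant([3.0, 4.0]), tf.Variable([0.0, 0.0]))]
clipped = apply_pre_allreduce_callbacks(
    grads_and_vars, [clip_by_global_norm_callback])
# The global norm of [3.0, 4.0] is 5.0, so the gradient is rescaled to norm 1.0.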