Commit dc2ae42c authored by Ran Chen's avatar Ran Chen Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 336215069
parent 2fa75516
...@@ -73,7 +73,7 @@ def _filter_and_allreduce_gradients(grads_and_vars, ...@@ -73,7 +73,7 @@ def _filter_and_allreduce_gradients(grads_and_vars,
hints = tf.distribute.experimental.CollectiveHints( hints = tf.distribute.experimental.CollectiveHints(
bytes_per_pack=bytes_per_pack) bytes_per_pack=bytes_per_pack)
allreduced_grads = tf.distribute.get_replica_context().all_reduce( allreduced_grads = tf.distribute.get_replica_context().all_reduce(
tf.distribute.ReduceOp.SUM, grads, experimental_hints=hints) tf.distribute.ReduceOp.SUM, grads, hints)
if allreduce_precision == "float16": if allreduce_precision == "float16":
allreduced_grads = [tf.cast(grad, "float32") for grad in allreduced_grads] allreduced_grads = [tf.cast(grad, "float32") for grad in allreduced_grads]
return allreduced_grads, variables return allreduced_grads, variables
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment