Commit a90e36a4 authored by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 371640518
parent 8454dedc
......@@ -65,10 +65,10 @@ def _filter_and_allreduce_gradients(grads_and_vars,
(grads, variables) = zip(*filtered_grads_and_vars)
if allreduce_precision == "float16":
grads = [tf.cast(grad, "float16") for grad in grads]
hints = tf.distribute.experimental.CollectiveHints(
hints = tf.distribute.experimental.CommunicationOptions(
bytes_per_pack=bytes_per_pack)
allreduced_grads = tf.distribute.get_replica_context().all_reduce(
tf.distribute.ReduceOp.SUM, grads, hints)
allreduced_grads = tf.distribute.get_strategy( # pylint: disable=protected-access
).extended._replica_ctx_all_reduce(tf.distribute.ReduceOp.SUM, grads, hints)
if allreduce_precision == "float16":
allreduced_grads = [tf.cast(grad, "float32") for grad in allreduced_grads]
return allreduced_grads, variables
......
......@@ -97,8 +97,7 @@ def run(flags_obj):
Returns:
Dictionary of training and eval stats.
"""
keras_utils.set_session_config(
enable_xla=flags_obj.enable_xla)
keras_utils.set_session_config()
performance.set_mixed_precision_policy(flags_core.get_tf_dtype(flags_obj))
if tf.config.list_physical_devices('GPU'):
......
......@@ -167,7 +167,8 @@ class ResnetRunnable(orbit.StandardTrainer, orbit.StandardEvaluator):
tape, self.optimizer, loss, self.model.trainable_variables)
self.train_loss.update_state(loss)
self.train_accuracy.update_state(labels, logits)
if self.flags_obj.enable_xla:
step_fn = tf.function(step_fn, jit_compile=True)
self.strategy.run(step_fn, args=(next(iterator),))
def train_loop_end(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment