Commit a90e36a4 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 371640518
parent 8454dedc
...@@ -65,10 +65,10 @@ def _filter_and_allreduce_gradients(grads_and_vars, ...@@ -65,10 +65,10 @@ def _filter_and_allreduce_gradients(grads_and_vars,
(grads, variables) = zip(*filtered_grads_and_vars) (grads, variables) = zip(*filtered_grads_and_vars)
if allreduce_precision == "float16": if allreduce_precision == "float16":
grads = [tf.cast(grad, "float16") for grad in grads] grads = [tf.cast(grad, "float16") for grad in grads]
hints = tf.distribute.experimental.CollectiveHints( hints = tf.distribute.experimental.CommunicationOptions(
bytes_per_pack=bytes_per_pack) bytes_per_pack=bytes_per_pack)
allreduced_grads = tf.distribute.get_replica_context().all_reduce( allreduced_grads = tf.distribute.get_strategy( # pylint: disable=protected-access
tf.distribute.ReduceOp.SUM, grads, hints) ).extended._replica_ctx_all_reduce(tf.distribute.ReduceOp.SUM, grads, hints)
if allreduce_precision == "float16": if allreduce_precision == "float16":
allreduced_grads = [tf.cast(grad, "float32") for grad in allreduced_grads] allreduced_grads = [tf.cast(grad, "float32") for grad in allreduced_grads]
return allreduced_grads, variables return allreduced_grads, variables
......
...@@ -97,8 +97,7 @@ def run(flags_obj): ...@@ -97,8 +97,7 @@ def run(flags_obj):
Returns: Returns:
Dictionary of training and eval stats. Dictionary of training and eval stats.
""" """
keras_utils.set_session_config( keras_utils.set_session_config()
enable_xla=flags_obj.enable_xla)
performance.set_mixed_precision_policy(flags_core.get_tf_dtype(flags_obj)) performance.set_mixed_precision_policy(flags_core.get_tf_dtype(flags_obj))
if tf.config.list_physical_devices('GPU'): if tf.config.list_physical_devices('GPU'):
......
...@@ -167,7 +167,8 @@ class ResnetRunnable(orbit.StandardTrainer, orbit.StandardEvaluator): ...@@ -167,7 +167,8 @@ class ResnetRunnable(orbit.StandardTrainer, orbit.StandardEvaluator):
tape, self.optimizer, loss, self.model.trainable_variables) tape, self.optimizer, loss, self.model.trainable_variables)
self.train_loss.update_state(loss) self.train_loss.update_state(loss)
self.train_accuracy.update_state(labels, logits) self.train_accuracy.update_state(labels, logits)
if self.flags_obj.enable_xla:
step_fn = tf.function(step_fn, jit_compile=True)
self.strategy.run(step_fn, args=(next(iterator),)) self.strategy.run(step_fn, args=(next(iterator),))
def train_loop_end(self): def train_loop_end(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment