Commit eaa4003c authored by Abdullah Rashwan's avatar Abdullah Rashwan Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 411096556
parent 19e3d234
......@@ -403,6 +403,9 @@ class Trainer(_AsyncTrainer):
"""See base class."""
def step_fn(inputs):
if self.config.runtime.enable_xla and (self.config.runtime.num_gpus > 0):
task_train_step = tf.function(self.task.train_step, jit_compile=True)
else:
task_train_step = self.task.train_step
logs = task_train_step(
inputs,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment