Commit eaa4003c authored by Abdullah Rashwan's avatar Abdullah Rashwan Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 411096556
parent 19e3d234
......@@ -403,7 +403,10 @@ class Trainer(_AsyncTrainer):
"""See base class."""
def step_fn(inputs):
task_train_step = self.task.train_step
if self.config.runtime.enable_xla and (self.config.runtime.num_gpus > 0):
task_train_step = tf.function(self.task.train_step, jit_compile=True)
else:
task_train_step = self.task.train_step
logs = task_train_step(
inputs,
model=self.model,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment