Commit eaa4003c authored by Abdullah Rashwan's avatar Abdullah Rashwan Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 411096556
parent 19e3d234
...@@ -403,7 +403,10 @@ class Trainer(_AsyncTrainer): ...@@ -403,7 +403,10 @@ class Trainer(_AsyncTrainer):
"""See base class.""" """See base class."""
def step_fn(inputs): def step_fn(inputs):
task_train_step = self.task.train_step if self.config.runtime.enable_xla and (self.config.runtime.num_gpus > 0):
task_train_step = tf.function(self.task.train_step, jit_compile=True)
else:
task_train_step = self.task.train_step
logs = task_train_step( logs = task_train_step(
inputs, inputs,
model=self.model, model=self.model,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment