Commit 5a5e0b7e authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

enable XLA around model training step via jit_compile

PiperOrigin-RevId: 366330392
parent 7687b1d3
...@@ -379,7 +379,11 @@ class Trainer(_AsyncTrainer): ...@@ -379,7 +379,11 @@ class Trainer(_AsyncTrainer):
"""See base class.""" """See base class."""
def step_fn(inputs): def step_fn(inputs):
logs = self.task.train_step( if self.config.runtime.enable_xla and (self.config.runtime.num_gpus > 0):
task_train_step = tf.function(self.task.train_step, jit_compile=True)
else:
task_train_step = self.task.train_step
logs = task_train_step(
inputs, inputs,
model=self.model, model=self.model,
optimizer=self.optimizer, optimizer=self.optimizer,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment