"docs/vscode:/vscode.git/clone" did not exist on "248c50515d21d495bc215c42fa5cb57d593f61bd"
Commit eaa4003c authored by Abdullah Rashwan's avatar Abdullah Rashwan Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 411096556
parent 19e3d234
...@@ -403,6 +403,9 @@ class Trainer(_AsyncTrainer): ...@@ -403,6 +403,9 @@ class Trainer(_AsyncTrainer):
"""See base class.""" """See base class."""
def step_fn(inputs): def step_fn(inputs):
if self.config.runtime.enable_xla and (self.config.runtime.num_gpus > 0):
task_train_step = tf.function(self.task.train_step, jit_compile=True)
else:
task_train_step = self.task.train_step task_train_step = self.task.train_step
logs = task_train_step( logs = task_train_step(
inputs, inputs,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment