Commit 124e3918 authored by Frederick Liu, committed by A. Unique TensorFlower

[core] Add tpu_enable_xla_dynamic_padder to RuntimeConfig.

PiperOrigin-RevId: 363595377
parent 5e5f12f3
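
For orientation, here is a minimal end-to-end sketch (not part of this commit) of the plumbing being added: a tri-state flag is forwarded into tf.tpu.XLAOptions and handed to strategy.run via tf.distribute.RunOptions. The flag value below is hypothetical; on non-TPU strategies the XLA options are simply ignored.

import tensorflow as tf

# Hypothetical user-side setting; None means "respect the XLA default".
tpu_enable_xla_dynamic_padder = False  # e.g. masking is handled in the model

xla_options = {}
if tpu_enable_xla_dynamic_padder is not None:
  xla_options["enable_xla_dynamic_padder"] = tpu_enable_xla_dynamic_padder
run_options = tf.distribute.RunOptions(
    experimental_xla_options=tf.tpu.XLAOptions(**xla_options))

strategy = tf.distribute.get_strategy()

@tf.function
def step_fn(x):
  return x * 2.0

# The options kwarg threads the XLA settings through to compilation on TPU.
outputs = strategy.run(step_fn, args=(tf.constant(1.0),), options=run_options)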
@@ -158,6 +158,16 @@ class _AsyncTrainer(orbit.StandardTrainer, orbit.StandardEvaluator):
         *args, **kwargs)
 
 
+def get_runtime_options(config: ExperimentConfig):
+  """Get tf.distribute.RunOptions from config."""
+  xla_options = {}
+  if config.runtime.tpu_enable_xla_dynamic_padder is not None:
+    xla_options["enable_xla_dynamic_padder"] = (
+        config.runtime.tpu_enable_xla_dynamic_padder)
+  return tf.distribute.RunOptions(
+      experimental_xla_options=tf.tpu.XLAOptions(**xla_options))
+
+
 @gin.configurable
 class Trainer(_AsyncTrainer):
   """Implements the common trainer shared for TensorFlow models."""
@@ -195,6 +205,7 @@ class Trainer(_AsyncTrainer):
     self._optimizer = optimizer
     self._checkpoint_exporter = checkpoint_exporter
     self._recovery = None
+    self._runtime_options = get_runtime_options(config)
 
     # Creates a shadow copy of the weights to store weights moving average.
     if isinstance(self._optimizer, optimization.ExponentialMovingAverage):
@@ -374,7 +385,8 @@ class Trainer(_AsyncTrainer):
       self._train_loss.update_state(logs[self.task.loss])
       self.global_step.assign_add(1)
 
-    self.strategy.run(step_fn, args=(next(iterator),))
+    self.strategy.run(
+        step_fn, args=(next(iterator),), options=self._runtime_options)
 
   def eval_begin(self):
     """Sets up metrics."""
@@ -395,7 +407,8 @@ class Trainer(_AsyncTrainer):
       self._validation_loss.update_state(logs[self.task.loss])
       return logs
 
-    distributed_outputs = self.strategy.run(step_fn, args=(next(iterator),))
+    distributed_outputs = self.strategy.run(
+        step_fn, args=(next(iterator),), options=self._runtime_options)
     return tf.nest.map_structure(self.strategy.experimental_local_results,
                                  distributed_outputs)
 
@@ -140,6 +140,16 @@ class RuntimeConfig(base_config.Config):
   run_eagerly: bool = False
   batchnorm_spatial_persistent: bool = False
 
+  # XLA runtime
+  # Whether to enable the XLA dynamic padder
+  # infrastructure to handle dynamically shaped inputs inside XLA. True by
+  # default. Disabling this may cause correctness issues with dynamically
+  # shaped inputs, as XLA will assume the inputs have padded shapes. However,
+  # users can optionally set it to False to improve device time if masking is
+  # already handled on the user side.
+  # If None, the XLA default is respected.
+  tpu_enable_xla_dynamic_padder: Optional[bool] = None
+
   # Global model parallelism configurations.
   num_cores_per_replica: int = 1
   default_shard_dim: int = -1
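
Since the field is Optional[bool], it is tri-state: None leaves XLAOptions at its defaults (padder enabled), while True or False is forwarded explicitly. A sketch of flipping it on an existing config, assuming the hyperparams library's override() path (exact API may vary by release):

# None  -> XLAOptions()                                  (XLA default)
# True  -> XLAOptions(enable_xla_dynamic_padder=True)
# False -> XLAOptions(enable_xla_dynamic_padder=False)   (user handles masking)
config.runtime.override({'tpu_enable_xla_dynamic_padder': False})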
@@ -92,6 +92,7 @@ class ProgressiveTrainer(trainer_lib.Trainer):
     # it gets a single-replica no-op strategy.
     self._strategy = tf.distribute.get_strategy()
     self._config = config
+    self._runtime_options = trainer_lib.get_runtime_options(config)
     self._task = prog_task
 
     # Directory for non-progressive checkpoint