Commit 64d412cb authored by Frederick Liu, committed by A. Unique TensorFlower

[core] Add tpu_enable_xla_dynamic_padder to RuntimeConfig.

PiperOrigin-RevId: 363595377
parent 739785bb
@@ -158,6 +158,16 @@ class _AsyncTrainer(orbit.StandardTrainer, orbit.StandardEvaluator):
                                                 *args, **kwargs)
 
 
+def get_runtime_options(config: ExperimentConfig):
+  """Get tf.distribute.RunOptions from config."""
+  xla_options = {}
+  if config.runtime.tpu_enable_xla_dynamic_padder is not None:
+    xla_options["enable_xla_dynamic_padder"] = (
+        config.runtime.tpu_enable_xla_dynamic_padder)
+  return tf.distribute.RunOptions(
+      experimental_xla_options=tf.tpu.XLAOptions(**xla_options))
+
+
 @gin.configurable
 class Trainer(_AsyncTrainer):
   """Implements the common trainer shared for TensorFlow models."""
@@ -195,6 +205,7 @@ class Trainer(_AsyncTrainer):
     self._optimizer = optimizer
     self._checkpoint_exporter = checkpoint_exporter
     self._recovery = None
+    self._runtime_options = get_runtime_options(config)
 
     # Creates a shadow copy of the weights to store weights moving average.
     if isinstance(self._optimizer, optimization.ExponentialMovingAverage):
@@ -374,7 +385,8 @@ class Trainer(_AsyncTrainer):
       self._train_loss.update_state(logs[self.task.loss])
       self.global_step.assign_add(1)
 
-    self.strategy.run(step_fn, args=(next(iterator),))
+    self.strategy.run(
+        step_fn, args=(next(iterator),), options=self._runtime_options)
 
   def eval_begin(self):
     """Sets up metrics."""
@@ -395,7 +407,8 @@ class Trainer(_AsyncTrainer):
       self._validation_loss.update_state(logs[self.task.loss])
       return logs
 
-    distributed_outputs = self.strategy.run(step_fn, args=(next(iterator),))
+    distributed_outputs = self.strategy.run(
+        step_fn, args=(next(iterator),), options=self._runtime_options)
     return tf.nest.map_structure(self.strategy.experimental_local_results,
                                  distributed_outputs)
...
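For readers following the change, here is a minimal, self-contained sketch (not part of the commit) of the pattern the trainer now uses: build a `tf.distribute.RunOptions` that only sets `enable_xla_dynamic_padder` on `tf.tpu.XLAOptions` when the configured value is not None, then pass those options into `strategy.run`. The `step_fn`, `dataset`, and `make_run_options` names below are illustrative stand-ins, not code from the repository.

```python
import tensorflow as tf

# Hypothetical stand-ins for the trainer's real step function and input pipeline.
def step_fn(batch):
  return tf.reduce_sum(batch)

dataset = tf.data.Dataset.from_tensor_slices(tf.ones([8, 4])).batch(4)

def make_run_options(tpu_enable_xla_dynamic_padder=None):
  # Mirror get_runtime_options: only forward the flag when it is set, so an
  # unset config field falls back to XLA's own default behavior.
  xla_options = {}
  if tpu_enable_xla_dynamic_padder is not None:
    xla_options["enable_xla_dynamic_padder"] = tpu_enable_xla_dynamic_padder
  return tf.distribute.RunOptions(
      experimental_xla_options=tf.tpu.XLAOptions(**xla_options))

strategy = tf.distribute.get_strategy()  # default (no-op) strategy for this sketch
iterator = iter(strategy.experimental_distribute_dataset(dataset))
options = make_run_options(tpu_enable_xla_dynamic_padder=False)

# Same call shape as the updated train_step/eval_step above.
strategy.run(step_fn, args=(next(iterator),), options=options)
```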
@@ -140,6 +140,16 @@ class RuntimeConfig(base_config.Config):
   run_eagerly: bool = False
   batchnorm_spatial_persistent: bool = False
 
+  # XLA runtime
+  # Whether to enable the XLA dynamic padder
+  # infrastructure to handle dynamically shaped inputs inside XLA. True by
+  # default. Disabling this may cause correctness issues with dynamically
+  # shaped inputs, as XLA will just assume the inputs have padded shapes.
+  # However, users can optionally set it to False to improve device time if
+  # masking is already handled on the user side.
+  # If None, will respect the XLA default.
+  tpu_enable_xla_dynamic_padder: Optional[bool] = None
+
   # Global model parallelism configurations.
   num_cores_per_replica: int = 1
   default_shard_dim: int = -1
...
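A rough usage sketch (again, not from the commit): the new field behaves like any other `RuntimeConfig` attribute. `None` leaves XLA's padder behavior untouched, `True` forces the padder on, and `False` skips it when masking is already handled on the model side. The import path is an assumption based on the Model Garden layout and may differ across versions.

```python
# Assumed import path for the config definitions; adjust to your checkout.
from official.core import config_definitions as cfg

# None (default): the flag is not forwarded, so XLA keeps its own default.
runtime = cfg.RuntimeConfig()
assert runtime.tpu_enable_xla_dynamic_padder is None

# False: skip the dynamic padder, e.g. when the model masks padded positions itself.
runtime = cfg.RuntimeConfig(tpu_enable_xla_dynamic_padder=False)

experiment = cfg.ExperimentConfig(runtime=runtime)
```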
@@ -92,6 +92,7 @@ class ProgressiveTrainer(trainer_lib.Trainer):
     # it gets a single-replica no-op strategy.
     self._strategy = tf.distribute.get_strategy()
     self._config = config
+    self._runtime_options = trainer_lib.get_runtime_options(config)
     self._task = prog_task
 
     # Directory for non-progressive checkpoint
...