Commit dfaf525e authored by Frederick Liu's avatar Frederick Liu Committed by A. Unique TensorFlower
Browse files

[core] Only use runtime_options for training.

PiperOrigin-RevId: 363782489
parent b6d1ec22
...@@ -163,7 +163,7 @@ def get_runtime_options(config: ExperimentConfig): ...@@ -163,7 +163,7 @@ def get_runtime_options(config: ExperimentConfig):
xla_options = {} xla_options = {}
if config.runtime.tpu_enable_xla_dynamic_padder is not None: if config.runtime.tpu_enable_xla_dynamic_padder is not None:
xla_options["enable_xla_dynamic_padder"] = ( xla_options["enable_xla_dynamic_padder"] = (
config.runtime.enable_xla_dynamic_padder) config.runtime.tpu_enable_xla_dynamic_padder)
return tf.distribute.RunOptions( return tf.distribute.RunOptions(
experimental_xla_options=tf.tpu.XLAOptions(**xla_options)) experimental_xla_options=tf.tpu.XLAOptions(**xla_options))
...@@ -205,6 +205,8 @@ class Trainer(_AsyncTrainer): ...@@ -205,6 +205,8 @@ class Trainer(_AsyncTrainer):
self._optimizer = optimizer self._optimizer = optimizer
self._checkpoint_exporter = checkpoint_exporter self._checkpoint_exporter = checkpoint_exporter
self._recovery = None self._recovery = None
# Runtime options are only applied to train_step.
# We use default for eval_step.
self._runtime_options = get_runtime_options(config) self._runtime_options = get_runtime_options(config)
# Creates a shadow copy of the weights to store weights moving average. # Creates a shadow copy of the weights to store weights moving average.
...@@ -407,8 +409,7 @@ class Trainer(_AsyncTrainer): ...@@ -407,8 +409,7 @@ class Trainer(_AsyncTrainer):
self._validation_loss.update_state(logs[self.task.loss]) self._validation_loss.update_state(logs[self.task.loss])
return logs return logs
distributed_outputs = self.strategy.run( distributed_outputs = self.strategy.run(step_fn, args=(next(iterator),))
step_fn, args=(next(iterator),), options=self._runtime_options)
return tf.nest.map_structure(self.strategy.experimental_local_results, return tf.nest.map_structure(self.strategy.experimental_local_results,
distributed_outputs) distributed_outputs)
......
...@@ -140,7 +140,11 @@ class RuntimeConfig(base_config.Config): ...@@ -140,7 +140,11 @@ class RuntimeConfig(base_config.Config):
run_eagerly: bool = False run_eagerly: bool = False
batchnorm_spatial_persistent: bool = False batchnorm_spatial_persistent: bool = False
# XLA runtime # XLA runtime params.
# XLA params are only applied to the train_step.
# These arguments can improve training speed. They can also improve eval, but
# may reduce usability, and users would need to make changes to their code.
# Whether to enable XLA dynamic padder # Whether to enable XLA dynamic padder
# infrastructure to handle dynamic shapes inputs inside XLA. True by # infrastructure to handle dynamic shapes inputs inside XLA. True by
# default. Disabling this may cause correctness issues with dynamic shapes # default. Disabling this may cause correctness issues with dynamic shapes
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment