Commit 1bbd359d authored by Reed Wanderman-Milne's avatar Reed Wanderman-Milne Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 368067415
parent ecbc3fd8
...@@ -48,9 +48,7 @@ def main(_): ...@@ -48,9 +48,7 @@ def main(_):
# GPUs, and bfloat16 in the case of TPUs. loss_scale takes effect only when # GPUs, and bfloat16 in the case of TPUs. loss_scale takes effect only when
# dtype is float16 # dtype is float16
if params.runtime.mixed_precision_dtype: if params.runtime.mixed_precision_dtype:
performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype, performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype)
params.runtime.loss_scale,
use_experimental_api=True)
distribution_strategy = distribute_utils.get_distribution_strategy( distribution_strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=params.runtime.distribution_strategy, distribution_strategy=params.runtime.distribution_strategy,
all_reduce_alg=params.runtime.all_reduce_alg, all_reduce_alg=params.runtime.all_reduce_alg,
......
...@@ -91,8 +91,7 @@ class SimCLRPretrainTask(base_task.Task): ...@@ -91,8 +91,7 @@ class SimCLRPretrainTask(base_task.Task):
optimizer = performance.configure_optimizer( optimizer = performance.configure_optimizer(
optimizer, optimizer,
use_float16=runtime_config.mixed_precision_dtype == 'float16', use_float16=runtime_config.mixed_precision_dtype == 'float16',
loss_scale=runtime_config.loss_scale, loss_scale=runtime_config.loss_scale)
use_experimental_api=True)
return optimizer return optimizer
...@@ -397,8 +396,7 @@ class SimCLRFinetuneTask(base_task.Task): ...@@ -397,8 +396,7 @@ class SimCLRFinetuneTask(base_task.Task):
optimizer = performance.configure_optimizer( optimizer = performance.configure_optimizer(
optimizer, optimizer,
use_float16=runtime_config.mixed_precision_dtype == 'float16', use_float16=runtime_config.mixed_precision_dtype == 'float16',
loss_scale=runtime_config.loss_scale, loss_scale=runtime_config.loss_scale)
use_experimental_api=True)
return optimizer return optimizer
......
...@@ -59,9 +59,7 @@ def main(_): ...@@ -59,9 +59,7 @@ def main(_):
# GPUs, and bfloat16 in the case of TPUs. loss_scale takes effect only when # GPUs, and bfloat16 in the case of TPUs. loss_scale takes effect only when
# dtype is float16 # dtype is float16
if params.runtime.mixed_precision_dtype: if params.runtime.mixed_precision_dtype:
performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype, performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype)
params.runtime.loss_scale,
use_experimental_api=True)
distribution_strategy = distribute_utils.get_distribution_strategy( distribution_strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=params.runtime.distribution_strategy, distribution_strategy=params.runtime.distribution_strategy,
all_reduce_alg=params.runtime.all_reduce_alg, all_reduce_alg=params.runtime.all_reduce_alg,
......
...@@ -50,9 +50,7 @@ def main(_): ...@@ -50,9 +50,7 @@ def main(_):
# GPUs, and bfloat16 in the case of TPUs. loss_scale takes effect only when # GPUs, and bfloat16 in the case of TPUs. loss_scale takes effect only when
# dtype is float16 # dtype is float16
if params.runtime.mixed_precision_dtype: if params.runtime.mixed_precision_dtype:
performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype, performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype)
params.runtime.loss_scale,
use_experimental_api=True)
distribution_strategy = distribute_utils.get_distribution_strategy( distribution_strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=params.runtime.distribution_strategy, distribution_strategy=params.runtime.distribution_strategy,
all_reduce_alg=params.runtime.all_reduce_alg, all_reduce_alg=params.runtime.all_reduce_alg,
......
...@@ -46,9 +46,7 @@ def main(_): ...@@ -46,9 +46,7 @@ def main(_):
# GPUs, and bfloat16 in the case of TPUs. loss_scale takes effect only when # GPUs, and bfloat16 in the case of TPUs. loss_scale takes effect only when
# dtype is float16 # dtype is float16
if params.runtime.mixed_precision_dtype: if params.runtime.mixed_precision_dtype:
performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype, performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype)
params.runtime.loss_scale,
use_experimental_api=True)
distribution_strategy = distribute_utils.get_distribution_strategy( distribution_strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=params.runtime.distribution_strategy, distribution_strategy=params.runtime.distribution_strategy,
all_reduce_alg=params.runtime.all_reduce_alg, all_reduce_alg=params.runtime.all_reduce_alg,
......
...@@ -46,9 +46,7 @@ def main(_): ...@@ -46,9 +46,7 @@ def main(_):
# GPUs, and bfloat16 in the case of TPUs. loss_scale takes effect only when # GPUs, and bfloat16 in the case of TPUs. loss_scale takes effect only when
# dtype is float16 # dtype is float16
if params.runtime.mixed_precision_dtype: if params.runtime.mixed_precision_dtype:
performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype, performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype)
params.runtime.loss_scale,
use_experimental_api=True)
distribution_strategy = distribute_utils.get_distribution_strategy( distribution_strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=params.runtime.distribution_strategy, distribution_strategy=params.runtime.distribution_strategy,
all_reduce_alg=params.runtime.all_reduce_alg, all_reduce_alg=params.runtime.all_reduce_alg,
......
...@@ -96,9 +96,7 @@ def main(_): ...@@ -96,9 +96,7 @@ def main(_):
# GPUs, and bfloat16 in the case of TPUs. loss_scale takes effect only when # GPUs, and bfloat16 in the case of TPUs. loss_scale takes effect only when
# dtype is float16 # dtype is float16
if params.runtime.mixed_precision_dtype: if params.runtime.mixed_precision_dtype:
performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype, performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype)
params.runtime.loss_scale,
use_experimental_api=True)
input_partition_dims = None input_partition_dims = None
if FLAGS.mode == 'train_and_eval': if FLAGS.mode == 'train_and_eval':
......
...@@ -227,8 +227,7 @@ def initialize(params: base_configs.ExperimentConfig, ...@@ -227,8 +227,7 @@ def initialize(params: base_configs.ExperimentConfig,
dataset_builder: dataset_factory.DatasetBuilder): dataset_builder: dataset_factory.DatasetBuilder):
"""Initializes backend related initializations.""" """Initializes backend related initializations."""
keras_utils.set_session_config(enable_xla=params.runtime.enable_xla) keras_utils.set_session_config(enable_xla=params.runtime.enable_xla)
performance.set_mixed_precision_policy(dataset_builder.dtype, performance.set_mixed_precision_policy(dataset_builder.dtype)
use_experimental_api=False)
if tf.config.list_physical_devices('GPU'): if tf.config.list_physical_devices('GPU'):
data_format = 'channels_first' data_format = 'channels_first'
else: else:
...@@ -341,8 +340,7 @@ def train_and_eval( ...@@ -341,8 +340,7 @@ def train_and_eval(
optimizer = performance.configure_optimizer( optimizer = performance.configure_optimizer(
optimizer, optimizer,
use_float16=train_builder.dtype == 'float16', use_float16=train_builder.dtype == 'float16',
loss_scale=get_loss_scale(params), loss_scale=get_loss_scale(params))
use_experimental_api=True)
metrics_map = _get_metrics(one_hot) metrics_map = _get_metrics(one_hot)
metrics = [metrics_map[metric] for metric in params.train.metrics] metrics = [metrics_map[metric] for metric in params.train.metrics]
......
...@@ -99,8 +99,7 @@ def run(flags_obj): ...@@ -99,8 +99,7 @@ def run(flags_obj):
""" """
keras_utils.set_session_config( keras_utils.set_session_config(
enable_xla=flags_obj.enable_xla) enable_xla=flags_obj.enable_xla)
performance.set_mixed_precision_policy(flags_core.get_tf_dtype(flags_obj), performance.set_mixed_precision_policy(flags_core.get_tf_dtype(flags_obj))
use_experimental_api=False)
if tf.config.list_physical_devices('GPU'): if tf.config.list_physical_devices('GPU'):
if flags_obj.tf_gpu_thread_mode: if flags_obj.tf_gpu_thread_mode:
......
...@@ -81,8 +81,7 @@ class ResnetRunnable(orbit.StandardTrainer, orbit.StandardEvaluator): ...@@ -81,8 +81,7 @@ class ResnetRunnable(orbit.StandardTrainer, orbit.StandardEvaluator):
self.optimizer, self.optimizer,
use_float16=self.dtype == tf.float16, use_float16=self.dtype == tf.float16,
use_graph_rewrite=use_graph_rewrite, use_graph_rewrite=use_graph_rewrite,
loss_scale=flags_core.get_loss_scale(flags_obj, default_for_fp16=128), loss_scale=flags_core.get_loss_scale(flags_obj, default_for_fp16=128))
use_experimental_api=False)
self.train_loss = tf.keras.metrics.Mean('train_loss', dtype=tf.float32) self.train_loss = tf.keras.metrics.Mean('train_loss', dtype=tf.float32)
self.train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy( self.train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment