Commit 1bbd359d authored by Reed Wanderman-Milne's avatar Reed Wanderman-Milne Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 368067415
parent ecbc3fd8
......@@ -48,9 +48,7 @@ def main(_):
# GPUs, and bfloat16 in the case of TPUs. loss_scale takes effect only when
# dtype is float16
if params.runtime.mixed_precision_dtype:
performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype,
params.runtime.loss_scale,
use_experimental_api=True)
performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype)
distribution_strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=params.runtime.distribution_strategy,
all_reduce_alg=params.runtime.all_reduce_alg,
......
......@@ -91,8 +91,7 @@ class SimCLRPretrainTask(base_task.Task):
optimizer = performance.configure_optimizer(
optimizer,
use_float16=runtime_config.mixed_precision_dtype == 'float16',
loss_scale=runtime_config.loss_scale,
use_experimental_api=True)
loss_scale=runtime_config.loss_scale)
return optimizer
......@@ -397,8 +396,7 @@ class SimCLRFinetuneTask(base_task.Task):
optimizer = performance.configure_optimizer(
optimizer,
use_float16=runtime_config.mixed_precision_dtype == 'float16',
loss_scale=runtime_config.loss_scale,
use_experimental_api=True)
loss_scale=runtime_config.loss_scale)
return optimizer
......
......@@ -59,9 +59,7 @@ def main(_):
# GPUs, and bfloat16 in the case of TPUs. loss_scale takes effect only when
# dtype is float16
if params.runtime.mixed_precision_dtype:
performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype,
params.runtime.loss_scale,
use_experimental_api=True)
performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype)
distribution_strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=params.runtime.distribution_strategy,
all_reduce_alg=params.runtime.all_reduce_alg,
......
......@@ -50,9 +50,7 @@ def main(_):
# GPUs, and bfloat16 in the case of TPUs. loss_scale takes effect only when
# dtype is float16
if params.runtime.mixed_precision_dtype:
performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype,
params.runtime.loss_scale,
use_experimental_api=True)
performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype)
distribution_strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=params.runtime.distribution_strategy,
all_reduce_alg=params.runtime.all_reduce_alg,
......
......@@ -46,9 +46,7 @@ def main(_):
# GPUs, and bfloat16 in the case of TPUs. loss_scale takes effect only when
# dtype is float16
if params.runtime.mixed_precision_dtype:
performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype,
params.runtime.loss_scale,
use_experimental_api=True)
performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype)
distribution_strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=params.runtime.distribution_strategy,
all_reduce_alg=params.runtime.all_reduce_alg,
......
......@@ -46,9 +46,7 @@ def main(_):
# GPUs, and bfloat16 in the case of TPUs. loss_scale takes effect only when
# dtype is float16
if params.runtime.mixed_precision_dtype:
performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype,
params.runtime.loss_scale,
use_experimental_api=True)
performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype)
distribution_strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=params.runtime.distribution_strategy,
all_reduce_alg=params.runtime.all_reduce_alg,
......
......@@ -96,9 +96,7 @@ def main(_):
# GPUs, and bfloat16 in the case of TPUs. loss_scale takes effect only when
# dtype is float16
if params.runtime.mixed_precision_dtype:
performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype,
params.runtime.loss_scale,
use_experimental_api=True)
performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype)
input_partition_dims = None
if FLAGS.mode == 'train_and_eval':
......
......@@ -227,8 +227,7 @@ def initialize(params: base_configs.ExperimentConfig,
dataset_builder: dataset_factory.DatasetBuilder):
"""Initializes backend related initializations."""
keras_utils.set_session_config(enable_xla=params.runtime.enable_xla)
performance.set_mixed_precision_policy(dataset_builder.dtype,
use_experimental_api=False)
performance.set_mixed_precision_policy(dataset_builder.dtype)
if tf.config.list_physical_devices('GPU'):
data_format = 'channels_first'
else:
......@@ -341,8 +340,7 @@ def train_and_eval(
optimizer = performance.configure_optimizer(
optimizer,
use_float16=train_builder.dtype == 'float16',
loss_scale=get_loss_scale(params),
use_experimental_api=True)
loss_scale=get_loss_scale(params))
metrics_map = _get_metrics(one_hot)
metrics = [metrics_map[metric] for metric in params.train.metrics]
......
......@@ -99,8 +99,7 @@ def run(flags_obj):
"""
keras_utils.set_session_config(
enable_xla=flags_obj.enable_xla)
performance.set_mixed_precision_policy(flags_core.get_tf_dtype(flags_obj),
use_experimental_api=False)
performance.set_mixed_precision_policy(flags_core.get_tf_dtype(flags_obj))
if tf.config.list_physical_devices('GPU'):
if flags_obj.tf_gpu_thread_mode:
......
......@@ -81,8 +81,7 @@ class ResnetRunnable(orbit.StandardTrainer, orbit.StandardEvaluator):
self.optimizer,
use_float16=self.dtype == tf.float16,
use_graph_rewrite=use_graph_rewrite,
loss_scale=flags_core.get_loss_scale(flags_obj, default_for_fp16=128),
use_experimental_api=False)
loss_scale=flags_core.get_loss_scale(flags_obj, default_for_fp16=128))
self.train_loss = tf.keras.metrics.Mean('train_loss', dtype=tf.float32)
self.train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment