Commit 5b9f6e2f authored by Reed Wanderman-Milne's avatar Reed Wanderman-Milne Committed by A. Unique TensorFlower
Browse files

Pass use_experimental_api=True to mixed precision functions.

The default is True, but I plan on changing it to False soon. After that, I plan on removing the argument and never using the experimental API.

PiperOrigin-RevId: 360724698
parent 4c2f424f
......@@ -80,7 +80,8 @@ class Task(tf.Module, metaclass=abc.ABCMeta):
optimizer = performance.configure_optimizer(
optimizer,
use_float16=runtime_config.mixed_precision_dtype == "float16",
loss_scale=runtime_config.loss_scale)
loss_scale=runtime_config.loss_scale,
use_experimental_api=True)
return optimizer
......
......@@ -46,7 +46,8 @@ def main(_):
# dtype is float16
if params.runtime.mixed_precision_dtype:
performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype,
params.runtime.loss_scale)
params.runtime.loss_scale,
use_experimental_api=True)
distribution_strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=params.runtime.distribution_strategy,
all_reduce_alg=params.runtime.all_reduce_alg,
......
......@@ -108,7 +108,8 @@ def run_continuous_finetune(
# dtype is float16
if params.runtime.mixed_precision_dtype:
performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype,
params.runtime.loss_scale)
params.runtime.loss_scale,
use_experimental_api=True)
distribution_strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=params.runtime.distribution_strategy,
all_reduce_alg=params.runtime.all_reduce_alg,
......
......@@ -47,7 +47,8 @@ def main(_):
# dtype is float16
if params.runtime.mixed_precision_dtype:
performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype,
params.runtime.loss_scale)
params.runtime.loss_scale,
use_experimental_api=True)
distribution_strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=params.runtime.distribution_strategy,
all_reduce_alg=params.runtime.all_reduce_alg,
......
......@@ -51,7 +51,8 @@ def main(_):
# dtype is float16
if params.runtime.mixed_precision_dtype:
performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype,
params.runtime.loss_scale)
params.runtime.loss_scale,
use_experimental_api=True)
distribution_strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=params.runtime.distribution_strategy,
all_reduce_alg=params.runtime.all_reduce_alg,
......
......@@ -47,7 +47,8 @@ def main(_):
# dtype is float16
if params.runtime.mixed_precision_dtype:
performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype,
params.runtime.loss_scale)
params.runtime.loss_scale,
use_experimental_api=True)
distribution_strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=params.runtime.distribution_strategy,
all_reduce_alg=params.runtime.all_reduce_alg,
......
......@@ -97,7 +97,8 @@ def main(_):
# dtype is float16
if params.runtime.mixed_precision_dtype:
performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype,
params.runtime.loss_scale)
params.runtime.loss_scale,
use_experimental_api=True)
input_partition_dims = None
if FLAGS.mode == 'train_and_eval':
......
......@@ -341,7 +341,8 @@ def train_and_eval(
optimizer = performance.configure_optimizer(
optimizer,
use_float16=train_builder.dtype == 'float16',
loss_scale=get_loss_scale(params))
loss_scale=get_loss_scale(params),
use_experimental_api=True)
metrics_map = _get_metrics(one_hot)
metrics = [metrics_map[metric] for metric in params.train.metrics]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment