Commit 4ab6f1b1 authored by Reed Wanderman-Milne, committed by A. Unique TensorFlower

Pass use_experimental_api=True to mixed precision functions.

The default is True, but I plan to change it to False soon. After that, I plan to remove the argument and stop using the experimental API altogether.

PiperOrigin-RevId: 360724698
parent 067d3e41
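For context: `use_experimental_api` in the Model Garden `performance` module selects between the deprecated `tf.keras.mixed_precision.experimental` API and the non-experimental API that replaced it in TF 2.4. A minimal sketch of the two code paths the flag chooses between (my illustration, not the module's actual code):

import tensorflow as tf

def set_policy(use_experimental_api=True):
  # Sketch only; assumes TF 2.4+, where both APIs coexist.
  if use_experimental_api:
    # Deprecated experimental API: the loss scale is part of the policy.
    policy = tf.keras.mixed_precision.experimental.Policy(
        'mixed_float16', loss_scale='dynamic')
    tf.keras.mixed_precision.experimental.set_policy(policy)
  else:
    # Replacement API (TF 2.4+): the policy carries only the dtype;
    # loss scaling is configured on the optimizer instead.
    tf.keras.mixed_precision.set_global_policy('mixed_float16')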
@@ -80,7 +80,8 @@ class Task(tf.Module, metaclass=abc.ABCMeta):
     optimizer = performance.configure_optimizer(
         optimizer,
         use_float16=runtime_config.mixed_precision_dtype == "float16",
-        loss_scale=runtime_config.loss_scale)
+        loss_scale=runtime_config.loss_scale,
+        use_experimental_api=True)
     return optimizer
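`configure_optimizer` wraps the optimizer for loss scaling when float16 is in use. A hedged sketch of what that wrapping looks like under each API, inferred from the call sites in this commit rather than copied from `performance.py`:

import tensorflow as tf

def configure_optimizer(optimizer, use_float16=False, loss_scale='dynamic',
                        use_experimental_api=True):
  """Sketch: wrap `optimizer` in a LossScaleOptimizer for float16 training."""
  if use_float16:
    if use_experimental_api:
      # Deprecated wrapper: takes 'dynamic' or a fixed numeric scale directly.
      optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
          optimizer, loss_scale)
    elif loss_scale == 'dynamic':
      # Non-experimental wrapper defaults to dynamic loss scaling.
      optimizer = tf.keras.mixed_precision.LossScaleOptimizer(optimizer)
    else:
      # Fixed loss scale with the non-experimental wrapper.
      optimizer = tf.keras.mixed_precision.LossScaleOptimizer(
          optimizer, dynamic=False, initial_scale=loss_scale)
  return optimizer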
@@ -46,7 +46,8 @@ def main(_):
   # dtype is float16
   if params.runtime.mixed_precision_dtype:
     performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype,
-                                           params.runtime.loss_scale)
+                                           params.runtime.loss_scale,
+                                           use_experimental_api=True)
   distribution_strategy = distribute_utils.get_distribution_strategy(
       distribution_strategy=params.runtime.distribution_strategy,
       all_reduce_alg=params.runtime.all_reduce_alg,
@@ -108,7 +108,8 @@ def run_continuous_finetune(
   # dtype is float16
   if params.runtime.mixed_precision_dtype:
     performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype,
-                                           params.runtime.loss_scale)
+                                           params.runtime.loss_scale,
+                                           use_experimental_api=True)
   distribution_strategy = distribute_utils.get_distribution_strategy(
       distribution_strategy=params.runtime.distribution_strategy,
       all_reduce_alg=params.runtime.all_reduce_alg,
@@ -47,7 +47,8 @@ def main(_):
   # dtype is float16
   if params.runtime.mixed_precision_dtype:
     performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype,
-                                           params.runtime.loss_scale)
+                                           params.runtime.loss_scale,
+                                           use_experimental_api=True)
   distribution_strategy = distribute_utils.get_distribution_strategy(
       distribution_strategy=params.runtime.distribution_strategy,
       all_reduce_alg=params.runtime.all_reduce_alg,
@@ -51,7 +51,8 @@ def main(_):
   # dtype is float16
   if params.runtime.mixed_precision_dtype:
     performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype,
-                                           params.runtime.loss_scale)
+                                           params.runtime.loss_scale,
+                                           use_experimental_api=True)
   distribution_strategy = distribute_utils.get_distribution_strategy(
       distribution_strategy=params.runtime.distribution_strategy,
       all_reduce_alg=params.runtime.all_reduce_alg,
@@ -47,7 +47,8 @@ def main(_):
   # dtype is float16
   if params.runtime.mixed_precision_dtype:
     performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype,
-                                           params.runtime.loss_scale)
+                                           params.runtime.loss_scale,
+                                           use_experimental_api=True)
   distribution_strategy = distribute_utils.get_distribution_strategy(
       distribution_strategy=params.runtime.distribution_strategy,
       all_reduce_alg=params.runtime.all_reduce_alg,
@@ -97,7 +97,8 @@ def main(_):
   # dtype is float16
   if params.runtime.mixed_precision_dtype:
     performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype,
-                                           params.runtime.loss_scale)
+                                           params.runtime.loss_scale,
+                                           use_experimental_api=True)
   input_partition_dims = None
   if FLAGS.mode == 'train_and_eval':
@@ -341,7 +341,8 @@ def train_and_eval(
   optimizer = performance.configure_optimizer(
       optimizer,
       use_float16=train_builder.dtype == 'float16',
-      loss_scale=get_loss_scale(params))
+      loss_scale=get_loss_scale(params),
+      use_experimental_api=True)
   metrics_map = _get_metrics(one_hot)
   metrics = [metrics_map[metric] for metric in params.train.metrics]
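Taken together, every call site follows the same two-step pattern: set the global mixed-precision policy, then configure the optimizer. A hypothetical end-to-end usage mirroring these call sites (the SGD optimizer and the 'dynamic' loss scale are placeholder choices, not values from this commit):

import tensorflow as tf
from official.modeling import performance

# Set the global policy first, then wrap the optimizer for loss scaling,
# passing use_experimental_api=True as this commit now does explicitly.
performance.set_mixed_precision_policy('float16', 'dynamic',
                                       use_experimental_api=True)
optimizer = performance.configure_optimizer(
    tf.keras.optimizers.SGD(learning_rate=0.1),
    use_float16=True,
    loss_scale='dynamic',
    use_experimental_api=True)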