Commit 417aa073 authored by Reed Wanderman-Milne, committed by A. Unique TensorFlower

Change default value of use_experimental_api to False.

PiperOrigin-RevId: 367105004
parent 8ccc242c
@@ -14,19 +14,19 @@
 """Functions and classes related to training performance."""
 from absl import logging
 import tensorflow as tf
-# TODO(b/181616568): Switch default value of `use_experimental_api` in both
-# of these functions to False.
 def configure_optimizer(optimizer,
                         use_float16=False,
                         use_graph_rewrite=False,
                         loss_scale='dynamic',
-                        use_experimental_api=True):
+                        use_experimental_api=False):
   """Configures optimizer object with performance options."""
+  if use_experimental_api:
+    logging.warning('Passing use_experimental_api=True is deprecated. The '
+                    'argument will be removed in the future.')
   if use_float16:
     # TODO(b/171936854): Move all methods to non-experimental api.
     if use_experimental_api:
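
For context, a minimal sketch of what the two code paths select, assuming this module wraps the standard Keras mixed-precision APIs; the SGD optimizer below is only an illustration, not the module's exact code:

import tensorflow as tf

optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)

# use_experimental_api=False (new default): dynamic loss scaling through
# the stable Keras mixed-precision API.
optimizer = tf.keras.mixed_precision.LossScaleOptimizer(optimizer)

# use_experimental_api=True (old default, now deprecated):
# optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
#     optimizer, loss_scale='dynamic')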
@@ -53,8 +53,11 @@ def configure_optimizer(optimizer,
 def set_mixed_precision_policy(dtype, loss_scale=None,
-                               use_experimental_api=True):
+                               use_experimental_api=False):
   """Sets mix precision policy."""
+  if use_experimental_api:
+    logging.warning('Passing use_experimental_api=True is deprecated. The '
+                    'argument will be removed in the future.')
   assert use_experimental_api or loss_scale is None, (
       'loss_scale cannot be specified if use_experimental_api is False. If the '
       'non-experimental API is used, specify the loss scaling configuration '
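
The assert exists because the non-experimental policy carries no loss scale. A hedged sketch of the two paths, for illustration only:

import tensorflow as tf

# use_experimental_api=False (new default): the global policy has no loss
# scale; loss scaling is configured on the optimizer instead.
tf.keras.mixed_precision.set_global_policy('mixed_float16')

# use_experimental_api=True (deprecated): the experimental Policy bundled a
# loss scale, which is why loss_scale is only legal on this path.
# policy = tf.keras.mixed_precision.experimental.Policy(
#     'mixed_float16', loss_scale='dynamic')
# tf.keras.mixed_precision.experimental.set_policy(policy)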
@@ -91,7 +91,8 @@ class SimCLRPretrainTask(base_task.Task):
     optimizer = performance.configure_optimizer(
         optimizer,
         use_float16=runtime_config.mixed_precision_dtype == 'float16',
-        loss_scale=runtime_config.loss_scale)
+        loss_scale=runtime_config.loss_scale,
+        use_experimental_api=True)
     return optimizer
@@ -396,7 +397,8 @@ class SimCLRFinetuneTask(base_task.Task):
     optimizer = performance.configure_optimizer(
         optimizer,
         use_float16=runtime_config.mixed_precision_dtype == 'float16',
-        loss_scale=runtime_config.loss_scale)
+        loss_scale=runtime_config.loss_scale,
+        use_experimental_api=True)
     return optimizer
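
Both SimCLR call sites pin use_experimental_api=True so their behavior is untouched by the new default, at the cost of the new deprecation warning. A sketch of the pinned call versus the eventual migration; the import path and the SGD optimizer are assumptions for illustration:

import tensorflow as tf
from official.modeling import performance  # assumed import path

optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)

# Pinned, as in this commit: keeps the deprecated experimental wrapping.
optimizer = performance.configure_optimizer(
    optimizer,
    use_float16=True,
    loss_scale='dynamic',
    use_experimental_api=True)

# Migrated: drop the argument and take the new default (False).
# optimizer = performance.configure_optimizer(
#     optimizer, use_float16=True, loss_scale='dynamic')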
@@ -60,7 +60,8 @@ def main(_):
   # dtype is float16
   if params.runtime.mixed_precision_dtype:
     performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype,
-                                           params.runtime.loss_scale)
+                                           params.runtime.loss_scale,
+                                           use_experimental_api=True)
   distribution_strategy = distribute_utils.get_distribution_strategy(
       distribution_strategy=params.runtime.distribution_strategy,
       all_reduce_alg=params.runtime.all_reduce_alg,
@@ -47,7 +47,8 @@ def main(_):
   # dtype is float16
   if params.runtime.mixed_precision_dtype:
     performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype,
-                                           params.runtime.loss_scale)
+                                           params.runtime.loss_scale,
+                                           use_experimental_api=True)
   distribution_strategy = distribute_utils.get_distribution_strategy(
       distribution_strategy=params.runtime.distribution_strategy,
       all_reduce_alg=params.runtime.all_reduce_alg,
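
With the new default, these two-argument calls would trip the assert in set_mixed_precision_policy, since loss_scale is only accepted on the experimental path; pinning use_experimental_api=True preserves the old behavior. A hedged migration sketch, where the import path is assumed and a real migration would also move loss scaling into configure_optimizer:

import tensorflow as tf
from official.modeling import performance  # assumed import path

# Pinned, as in this commit: loss_scale still travels with the policy.
performance.set_mixed_precision_policy(tf.float16, 'dynamic',
                                       use_experimental_api=True)

# Migrated: the policy no longer carries a loss scale.
# performance.set_mixed_precision_policy(tf.float16)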