Commit 417aa073 authored by Reed Wanderman-Milne's avatar Reed Wanderman-Milne Committed by A. Unique TensorFlower
Browse files

Change default value of use_experimental_api to False.

PiperOrigin-RevId: 367105004
parent 8ccc242c
...@@ -14,19 +14,19 @@ ...@@ -14,19 +14,19 @@
"""Functions and classes related to training performance.""" """Functions and classes related to training performance."""
from absl import logging
import tensorflow as tf import tensorflow as tf
# TODO(b/181616568): Switch default value of `use_experimental_api` in both
# of these functions to False.
def configure_optimizer(optimizer, def configure_optimizer(optimizer,
use_float16=False, use_float16=False,
use_graph_rewrite=False, use_graph_rewrite=False,
loss_scale='dynamic', loss_scale='dynamic',
use_experimental_api=True): use_experimental_api=False):
"""Configures optimizer object with performance options.""" """Configures optimizer object with performance options."""
if use_experimental_api:
logging.warning('Passing use_experimental_api=True is deprecated. The '
'argument will be removed in the future.')
if use_float16: if use_float16:
# TODO(b/171936854): Move all methods to non-experimental api. # TODO(b/171936854): Move all methods to non-experimental api.
if use_experimental_api: if use_experimental_api:
...@@ -53,8 +53,11 @@ def configure_optimizer(optimizer, ...@@ -53,8 +53,11 @@ def configure_optimizer(optimizer,
def set_mixed_precision_policy(dtype, loss_scale=None, def set_mixed_precision_policy(dtype, loss_scale=None,
use_experimental_api=True): use_experimental_api=False):
"""Sets mix precision policy.""" """Sets mix precision policy."""
if use_experimental_api:
logging.warning('Passing use_experimental_api=True is deprecated. The '
'argument will be removed in the future.')
assert use_experimental_api or loss_scale is None, ( assert use_experimental_api or loss_scale is None, (
'loss_scale cannot be specified if use_experimental_api is False. If the ' 'loss_scale cannot be specified if use_experimental_api is False. If the '
'non-experimental API is used, specify the loss scaling configuration ' 'non-experimental API is used, specify the loss scaling configuration '
......
...@@ -91,7 +91,8 @@ class SimCLRPretrainTask(base_task.Task): ...@@ -91,7 +91,8 @@ class SimCLRPretrainTask(base_task.Task):
optimizer = performance.configure_optimizer( optimizer = performance.configure_optimizer(
optimizer, optimizer,
use_float16=runtime_config.mixed_precision_dtype == 'float16', use_float16=runtime_config.mixed_precision_dtype == 'float16',
loss_scale=runtime_config.loss_scale) loss_scale=runtime_config.loss_scale,
use_experimental_api=True)
return optimizer return optimizer
...@@ -396,7 +397,8 @@ class SimCLRFinetuneTask(base_task.Task): ...@@ -396,7 +397,8 @@ class SimCLRFinetuneTask(base_task.Task):
optimizer = performance.configure_optimizer( optimizer = performance.configure_optimizer(
optimizer, optimizer,
use_float16=runtime_config.mixed_precision_dtype == 'float16', use_float16=runtime_config.mixed_precision_dtype == 'float16',
loss_scale=runtime_config.loss_scale) loss_scale=runtime_config.loss_scale,
use_experimental_api=True)
return optimizer return optimizer
......
...@@ -60,7 +60,8 @@ def main(_): ...@@ -60,7 +60,8 @@ def main(_):
# dtype is float16 # dtype is float16
if params.runtime.mixed_precision_dtype: if params.runtime.mixed_precision_dtype:
performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype, performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype,
params.runtime.loss_scale) params.runtime.loss_scale,
use_experimental_api=True)
distribution_strategy = distribute_utils.get_distribution_strategy( distribution_strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=params.runtime.distribution_strategy, distribution_strategy=params.runtime.distribution_strategy,
all_reduce_alg=params.runtime.all_reduce_alg, all_reduce_alg=params.runtime.all_reduce_alg,
......
...@@ -47,7 +47,8 @@ def main(_): ...@@ -47,7 +47,8 @@ def main(_):
# dtype is float16 # dtype is float16
if params.runtime.mixed_precision_dtype: if params.runtime.mixed_precision_dtype:
performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype, performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype,
params.runtime.loss_scale) params.runtime.loss_scale,
use_experimental_api=True)
distribution_strategy = distribute_utils.get_distribution_strategy( distribution_strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=params.runtime.distribution_strategy, distribution_strategy=params.runtime.distribution_strategy,
all_reduce_alg=params.runtime.all_reduce_alg, all_reduce_alg=params.runtime.all_reduce_alg,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment