Commit c0ac8d1c authored by Reed Wanderman-Milne, committed by A. Unique TensorFlower

Use nonexperimental mixed precision API.

This replaces symbols in tf.keras.mixed_precision.experimental with the corresponding nonexperimental symbols. In some cases, passing a Policy is replaced with passing a policy name for conciseness.
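For reference, the rename is mechanical. A minimal before/after sketch (not part of the diff below; the experimental calls only run on TF versions that still ship the tf.keras.mixed_precision.experimental namespace):

import tensorflow as tf

# Before: experimental API, constructing and passing a Policy object.
policy = tf.keras.mixed_precision.experimental.Policy('mixed_float16')
tf.keras.mixed_precision.experimental.set_policy(policy)

# After: nonexperimental API, passing the policy name directly.
tf.keras.mixed_precision.set_global_policy('mixed_float16')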

Additionally, for the Shakespeare model, the loss_scale flag is removed, since supporting it with the nonexperimental API is slightly more verbose and users are encouraged to use the default loss scale anyway.
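To illustrate the verbosity being avoided: with the experimental API the loss scale was a Policy constructor argument, whereas the nonexperimental API requires wrapping the optimizer in a LossScaleOptimizer explicitly. A hedged sketch of what keeping the flag would have looked like (the loss_scale value below is an illustrative stand-in for the removed flag, not code from this commit):

import tensorflow as tf

tf.keras.mixed_precision.set_global_policy('mixed_float16')
opt = tf.keras.optimizers.Adam()

loss_scale = 'dynamic'  # illustrative stand-in for the removed flag
if loss_scale == 'dynamic':
  # Dynamic loss scaling is the default for LossScaleOptimizer.
  opt = tf.keras.mixed_precision.LossScaleOptimizer(opt)
else:
  # A fixed loss scale requires disabling dynamic scaling explicitly.
  opt = tf.keras.mixed_precision.LossScaleOptimizer(
      opt, dynamic=False, initial_scale=float(loss_scale))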

PiperOrigin-RevId: 368123944
parent 08fe7f0a
@@ -91,13 +91,12 @@ class KerasImagenetTest(tf.test.TestCase):
   def setUp(self):
     super(KerasImagenetTest, self).setUp()
     imagenet_preprocessing.NUM_IMAGES["validation"] = 4
-    self.policy = \
-        tf.keras.mixed_precision.experimental.global_policy()
+    self.policy = tf.keras.mixed_precision.global_policy()

   def tearDown(self):
     super(KerasImagenetTest, self).tearDown()
     tf.io.gfile.rmtree(self.get_temp_dir())
-    tf.keras.mixed_precision.experimental.set_policy(self.policy)
+    tf.keras.mixed_precision.set_global_policy(self.policy)

   def get_extra_flags_dict(self, flags_key):
     return self._extra_flags_dict[flags_key] + self._default_flags_dict
...
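The setUp/tearDown pair exists because the global policy is process-wide state: a test that changes it must restore whatever it found. A minimal sketch of the same pattern outside a test class (assuming the TF 2.4+ names used in this commit):

import tensorflow as tf

# Snapshot the process-wide policy, change it, and restore it afterwards.
saved_policy = tf.keras.mixed_precision.global_policy()
try:
  tf.keras.mixed_precision.set_global_policy('mixed_float16')
  # ... code that should run under mixed precision ...
finally:
  tf.keras.mixed_precision.set_global_policy(saved_policy)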
@@ -66,13 +66,12 @@ class KerasImagenetTest(tf.test.TestCase, parameterized.TestCase):
   def setUp(self):
     super(KerasImagenetTest, self).setUp()
     imagenet_preprocessing.NUM_IMAGES["validation"] = 4
-    self.policy = \
-        tf.keras.mixed_precision.experimental.global_policy()
+    self.policy = tf.keras.mixed_precision.global_policy()

   def tearDown(self):
     super(KerasImagenetTest, self).tearDown()
     tf.io.gfile.rmtree(self.get_temp_dir())
-    tf.keras.mixed_precision.experimental.set_policy(self.policy)
+    tf.keras.mixed_precision.set_global_policy(self.policy)

   @parameterized.parameters([
       "resnet",
...
@@ -57,7 +57,6 @@ def define_flags():
       synthetic_data=False,
       max_train_steps=False,
       dtype=True,
-      loss_scale=True,
       enable_xla=True)

   flags_core.set_defaults(train_epochs=43,
@@ -185,8 +184,8 @@ def train_model(flags_obj, dataset, vocab_size, strategy, checkpoint_dir=None):
   model = build_model(vocab_size=vocab_size, batch_size=flags_obj.batch_size,
                       use_cudnn=flags_obj.cudnn)

-  # When keras_use_ctl is False, Model.fit() automatically applies
-  # loss scaling so we don't need to create a LossScaleOptimizer.
+  # Model.fit() automatically applies loss scaling so we don't need to create
+  # a LossScaleOptimizer.
   model.compile(
       optimizer=tf.keras.optimizers.Adam(),
       loss=tf.keras.losses.CategoricalCrossentropy(),
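The updated comment leans on Model.fit() handling loss scaling. When a model is trained with a custom loop instead (the keras_use_ctl path the old comment referred to), scaling has to be done by hand. A hedged sketch of that pattern with the nonexperimental API, assuming a model and loss_fn are in scope:

import tensorflow as tf

optimizer = tf.keras.mixed_precision.LossScaleOptimizer(
    tf.keras.optimizers.Adam())

@tf.function
def train_step(model, loss_fn, x, y):
  with tf.GradientTape() as tape:
    loss = loss_fn(y, model(x, training=True))
    # Scale the loss up so small float16 gradients don't underflow.
    scaled_loss = optimizer.get_scaled_loss(loss)
  scaled_grads = tape.gradient(scaled_loss, model.trainable_variables)
  # Undo the scaling before the optimizer applies the gradients.
  grads = optimizer.get_unscaled_gradients(scaled_grads)
  optimizer.apply_gradients(zip(grads, model.trainable_variables))
  return loss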
@@ -269,11 +268,7 @@ def run(flags_obj):
                                'shakespeare.txt')

   if flags_obj.dtype == 'fp16':
-    policy = tf.keras.mixed_precision.experimental.Policy(
-        'mixed_float16',
-        loss_scale=flags_core.get_loss_scale(flags_obj,
-                                             default_for_fp16='dynamic'))
-    tf.keras.mixed_precision.experimental.set_policy(policy)
+    tf.keras.mixed_precision.set_global_policy('mixed_float16')

   keras_utils.set_session_config(
       enable_xla=flags_obj.enable_xla)
...
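What the one-line replacement buys: after set_global_policy('mixed_float16'), layers built from then on compute in float16 while keeping float32 variables, and Model.fit() adds dynamic loss scaling on its own (as the comment in the previous hunk notes). A quick check of that behavior, as a sketch:

import tensorflow as tf

tf.keras.mixed_precision.set_global_policy('mixed_float16')
layer = tf.keras.layers.Dense(4)
layer.build(input_shape=(None, 8))
print(layer.compute_dtype)  # float16: computations run in half precision
print(layer.dtype)          # float32: variables stay in full precision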
@@ -97,9 +97,8 @@ class Unet3DAccuracyBenchmark(keras_benchmark.KerasBenchmark):
     input_dtype = params.dtype
     if input_dtype == 'float16' or input_dtype == 'bfloat16':
-      policy = tf.keras.mixed_precision.experimental.Policy(
+      tf.keras.mixed_precision.set_global_policy(
           'mixed_bfloat16' if input_dtype == 'bfloat16' else 'mixed_float16')
-      tf.keras.mixed_precision.experimental.set_policy(policy)

     stats = {}
     start_time_sec = time.time()
...