"src/vscode:/vscode.git/clone" did not exist on "8ccc76ab3760cdb1ab60c7a344e16f118bb58adc"
Commit c0ac8d1c authored by Reed Wanderman-Milne, committed by A. Unique TensorFlower

Use nonexperimental mixed precision API.

This replaces symbols in tf.keras.mixed_precision.experimental with the corresponding nonexperimental symbols. In some cases, passing a Policy is replaced with passing a policy name for conciseness.

Additionally, for the Shakespeare model, the loss_scale flag is removed: supporting it with the nonexperimental API would be slightly more verbose, and using the default loss scale is recommended anyway.

PiperOrigin-RevId: 368123944
parent 08fe7f0a
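
As a reader's aid, a minimal sketch of the migration this commit performs (assuming TF 2.4+, where the nonexperimental API is available; this snippet is illustrative and not part of the commit):

import tensorflow as tf

# Old experimental API (TF 2.3 and earlier):
#   policy = tf.keras.mixed_precision.experimental.Policy('mixed_float16')
#   tf.keras.mixed_precision.experimental.set_policy(policy)

# New nonexperimental API (TF 2.4+): a policy name string can be passed
# directly, so constructing a Policy object is often unnecessary.
tf.keras.mixed_precision.set_global_policy('mixed_float16')

# The active policy can be read back, e.g. to save and restore it around a
# test, as the setUp()/tearDown() changes below do:
policy = tf.keras.mixed_precision.global_policy()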
@@ -91,13 +91,12 @@ class KerasImagenetTest(tf.test.TestCase):

   def setUp(self):
     super(KerasImagenetTest, self).setUp()
     imagenet_preprocessing.NUM_IMAGES["validation"] = 4
-    self.policy = \
-        tf.keras.mixed_precision.experimental.global_policy()
+    self.policy = tf.keras.mixed_precision.global_policy()

   def tearDown(self):
     super(KerasImagenetTest, self).tearDown()
     tf.io.gfile.rmtree(self.get_temp_dir())
-    tf.keras.mixed_precision.experimental.set_policy(self.policy)
+    tf.keras.mixed_precision.set_global_policy(self.policy)

   def get_extra_flags_dict(self, flags_key):
     return self._extra_flags_dict[flags_key] + self._default_flags_dict
...
@@ -66,13 +66,12 @@ class KerasImagenetTest(tf.test.TestCase, parameterized.TestCase):

   def setUp(self):
     super(KerasImagenetTest, self).setUp()
     imagenet_preprocessing.NUM_IMAGES["validation"] = 4
-    self.policy = \
-        tf.keras.mixed_precision.experimental.global_policy()
+    self.policy = tf.keras.mixed_precision.global_policy()

   def tearDown(self):
     super(KerasImagenetTest, self).tearDown()
     tf.io.gfile.rmtree(self.get_temp_dir())
-    tf.keras.mixed_precision.experimental.set_policy(self.policy)
+    tf.keras.mixed_precision.set_global_policy(self.policy)

   @parameterized.parameters([
       "resnet",
...
@@ -57,7 +57,6 @@ def define_flags():
       synthetic_data=False,
       max_train_steps=False,
       dtype=True,
-      loss_scale=True,
       enable_xla=True)

   flags_core.set_defaults(train_epochs=43,
@@ -185,8 +184,8 @@ def train_model(flags_obj, dataset, vocab_size, strategy, checkpoint_dir=None):
   model = build_model(vocab_size=vocab_size, batch_size=flags_obj.batch_size,
                       use_cudnn=flags_obj.cudnn)
-  # When keras_use_ctl is False, Model.fit() automatically applies
-  # loss scaling so we don't need to create a LossScaleOptimizer.
+  # Model.fit() automatically applies loss scaling so we don't need to create
+  # a LossScaleOptimizer.
   model.compile(
       optimizer=tf.keras.optimizers.Adam(),
       loss=tf.keras.losses.CategoricalCrossentropy(),
@@ -269,11 +268,7 @@ def run(flags_obj):
                             'shakespeare.txt')

   if flags_obj.dtype == 'fp16':
-    policy = tf.keras.mixed_precision.experimental.Policy(
-        'mixed_float16',
-        loss_scale=flags_core.get_loss_scale(flags_obj,
-                                             default_for_fp16='dynamic'))
-    tf.keras.mixed_precision.experimental.set_policy(policy)
+    tf.keras.mixed_precision.set_global_policy('mixed_float16')

   keras_utils.set_session_config(
       enable_xla=flags_obj.enable_xla)
...
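
Context for the loss_scale removal above: the nonexperimental Policy constructor no longer accepts a loss_scale argument, so a non-default loss scale would require wrapping the optimizer explicitly. A hedged sketch of what that would look like (assuming TF 2.4+; the initial_scale value is illustrative):

import tensorflow as tf

tf.keras.mixed_precision.set_global_policy('mixed_float16')

# Default behavior: under a 'mixed_float16' policy, Model.fit() wraps the
# optimizer in a dynamic LossScaleOptimizer automatically, so no extra code
# is needed and dynamic loss scaling is the recommended default.
opt = tf.keras.optimizers.Adam()

# Supporting a custom fixed loss scale would require explicit wrapping:
fixed_opt = tf.keras.mixed_precision.LossScaleOptimizer(
    tf.keras.optimizers.Adam(), dynamic=False, initial_scale=1024)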
@@ -97,9 +97,8 @@ class Unet3DAccuracyBenchmark(keras_benchmark.KerasBenchmark):
     input_dtype = params.dtype
     if input_dtype == 'float16' or input_dtype == 'bfloat16':
-      policy = tf.keras.mixed_precision.experimental.Policy(
-          'mixed_bfloat16' if input_dtype == 'bfloat16' else 'mixed_float16')
-      tf.keras.mixed_precision.experimental.set_policy(policy)
+      tf.keras.mixed_precision.set_global_policy(
+          'mixed_bfloat16' if input_dtype == 'bfloat16' else 'mixed_float16')

     stats = {}
     start_time_sec = time.time()
...