Unverified Commit 651e48e1 authored by Matt's avatar Matt Committed by GitHub
Browse files

Fix tests of mixed precision now that experimental is deprecated (#17300)

* Fix tests of mixed precision now that experimental is deprecated

* Fix mixed precision in training_args_tf.py too
parent 6d211429
......@@ -195,8 +195,7 @@ class TFTrainingArguments(TrainingArguments):
# Set to float16 at first
if self.fp16:
policy = tf.keras.mixed_precision.experimental.Policy("mixed_float16")
tf.keras.mixed_precision.experimental.set_policy(policy)
tf.keras.mixed_precision.set_global_policy("mixed_float16")
if self.no_cuda:
strategy = tf.distribute.OneDeviceStrategy(device="/cpu:0")
......@@ -217,8 +216,7 @@ class TFTrainingArguments(TrainingArguments):
if tpu:
# Set to bfloat16 in case of TPU
if self.fp16:
policy = tf.keras.mixed_precision.experimental.Policy("mixed_bfloat16")
tf.keras.mixed_precision.experimental.set_policy(policy)
tf.keras.mixed_precision.set_global_policy("mixed_bfloat16")
tf.config.experimental_connect_to_cluster(tpu)
tf.tpu.experimental.initialize_tpu_system(tpu)
......
......@@ -205,7 +205,7 @@ class TFCoreModelTesterMixin:
@slow
def test_mixed_precision(self):
tf.keras.mixed_precision.experimental.set_policy("mixed_float16")
tf.keras.mixed_precision.set_global_policy("mixed_float16")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
......@@ -216,7 +216,7 @@ class TFCoreModelTesterMixin:
self.assertIsNotNone(outputs)
tf.keras.mixed_precision.experimental.set_policy("float32")
tf.keras.mixed_precision.set_global_policy("float32")
@slow
def test_train_pipeline_custom_model(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment