Commit 4334a892 authored by Reed Wanderman-Milne, committed by A. Unique TensorFlower

Use nonexperimental mixed precision API.

This replaces symbols in tf.keras.mixed_precision.experimental with the corresponding nonexperimental symbols. In some cases, passing a Policy is replaced with passing a policy name for conciseness.
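
A minimal sketch of the mapping (illustrative only, not part of this change; the Dense layer is a placeholder):

  import tensorflow as tf

  # Before, with the experimental API:
  #   tf.keras.mixed_precision.experimental.set_policy("mixed_float16")
  #   policy = tf.keras.mixed_precision.experimental.Policy("mixed_float16")
  #   layer = tf.keras.layers.Dense(10, dtype=policy)
  # After, with the nonexperimental API:
  tf.keras.mixed_precision.set_global_policy("mixed_float16")
  policy = tf.keras.mixed_precision.global_policy()
  # Layers also accept the policy name directly, which is the shorter form used here:
  layer = tf.keras.layers.Dense(10, dtype="mixed_float16")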

Additionally, for the Shakespeare model, the loss_scale flag is removed, since supporting it with the nonexperimental API is slightly more verbose and the default loss scale is recommended anyway.
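
For reference, a custom loss scale can still be configured with the nonexperimental API; it just takes an explicit LossScaleOptimizer (sketch only; the Adam optimizer and the scale value 128 are placeholders, not part of this change):

  import tensorflow as tf

  opt = tf.keras.optimizers.Adam()
  # Default and recommended: dynamic loss scaling.
  opt = tf.keras.mixed_precision.LossScaleOptimizer(opt)
  # A fixed loss scale, roughly what the removed flag allowed:
  # opt = tf.keras.mixed_precision.LossScaleOptimizer(
  #     tf.keras.optimizers.Adam(), dynamic=False, initial_scale=128)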

PiperOrigin-RevId: 368123944
parent 19d18c00
@@ -29,7 +29,7 @@ class BertEncoderTest(keras_parameterized.TestCase):
  def tearDown(self):
    super(BertEncoderTest, self).tearDown()
-    tf.keras.mixed_precision.experimental.set_policy("float32")
+    tf.keras.mixed_precision.set_global_policy("float32")

  def test_network_creation(self):
    hidden_size = 32
@@ -92,7 +92,7 @@ class BertEncoderTest(keras_parameterized.TestCase):
  def test_network_creation_with_float16_dtype(self):
    hidden_size = 32
    sequence_length = 21
-    tf.keras.mixed_precision.experimental.set_policy("mixed_float16")
+    tf.keras.mixed_precision.set_global_policy("mixed_float16")
    # Create a small BertEncoder for testing.
    test_network = bert_encoder.BertEncoder(
        vocab_size=100,
......
@@ -45,9 +45,9 @@ class OnDeviceEmbeddingTest(keras_parameterized.TestCase):
  def test_layer_creation_with_mixed_precision(self):
    vocab_size = 31
    embedding_width = 27
-    policy = tf.keras.mixed_precision.experimental.Policy("mixed_float16")
    test_layer = on_device_embedding.OnDeviceEmbedding(
-        vocab_size=vocab_size, embedding_width=embedding_width, dtype=policy)
+        vocab_size=vocab_size, embedding_width=embedding_width,
+        dtype="mixed_float16")
    # Create a 2-dimensional input (the first dimension is implicit).
    sequence_length = 23
    input_tensor = tf.keras.Input(shape=(sequence_length), dtype=tf.int32)
@@ -83,9 +83,9 @@ class OnDeviceEmbeddingTest(keras_parameterized.TestCase):
  def test_layer_invocation_with_mixed_precision(self):
    vocab_size = 31
    embedding_width = 27
-    policy = tf.keras.mixed_precision.experimental.Policy("mixed_float16")
    test_layer = on_device_embedding.OnDeviceEmbedding(
-        vocab_size=vocab_size, embedding_width=embedding_width, dtype=policy)
+        vocab_size=vocab_size, embedding_width=embedding_width,
+        dtype="mixed_float16")
    # Create a 2-dimensional input (the first dimension is implicit).
    sequence_length = 23
    input_tensor = tf.keras.Input(shape=(sequence_length), dtype=tf.int32)
@@ -123,11 +123,10 @@ class OnDeviceEmbeddingTest(keras_parameterized.TestCase):
  def test_one_hot_layer_creation_with_mixed_precision(self):
    vocab_size = 31
    embedding_width = 27
-    policy = tf.keras.mixed_precision.experimental.Policy("mixed_float16")
    test_layer = on_device_embedding.OnDeviceEmbedding(
        vocab_size=vocab_size,
        embedding_width=embedding_width,
-        dtype=policy,
+        dtype="mixed_float16",
        use_one_hot=True)
    # Create a 2-dimensional input (the first dimension is implicit).
    sequence_length = 23
@@ -166,11 +165,10 @@ class OnDeviceEmbeddingTest(keras_parameterized.TestCase):
  def test_one_hot_layer_invocation_with_mixed_precision(self):
    vocab_size = 31
    embedding_width = 27
-    policy = tf.keras.mixed_precision.experimental.Policy("mixed_float16")
    test_layer = on_device_embedding.OnDeviceEmbedding(
        vocab_size=vocab_size,
        embedding_width=embedding_width,
-        dtype=policy,
+        dtype="mixed_float16",
        use_one_hot=True)
    # Create a 2-dimensional input (the first dimension is implicit).
    sequence_length = 23
......
@@ -159,7 +159,7 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
        kernel_initializer=self._kernel_initializer,
        name="intermediate",
        **common_kwargs)
-    policy = tf.keras.mixed_precision.experimental.global_policy()
+    policy = tf.keras.mixed_precision.global_policy()
    if policy.name == "mixed_bfloat16":
      # bfloat16 causes BERT with the LAMB optimizer to not converge
      # as well, so we use float32.
......
@@ -29,7 +29,7 @@ class TransformerEncoderBlockLayerTest(keras_parameterized.TestCase):
  def tearDown(self):
    super(TransformerEncoderBlockLayerTest, self).tearDown()
-    tf.keras.mixed_precision.experimental.set_policy('float32')
+    tf.keras.mixed_precision.set_global_policy('float32')

  def test_layer_creation(self, transformer_cls):
    test_layer = transformer_cls(
@@ -180,7 +180,7 @@ class TransformerEncoderBlockLayerTest(keras_parameterized.TestCase):
        new_output_tensor, output_tensor[:, 0:1, :], atol=5e-5, rtol=0.003)

  def test_layer_invocation_with_float16_dtype(self, transformer_cls):
-    tf.keras.mixed_precision.experimental.set_policy('mixed_float16')
+    tf.keras.mixed_precision.set_global_policy('mixed_float16')
    test_layer = transformer_cls(
        num_attention_heads=10, inner_dim=2048, inner_activation='relu')
    sequence_length = 21
......
@@ -108,7 +108,7 @@ class GatedFeedforward(tf.keras.layers.Layer):
    self._output_dense = []
    self._output_dropout = []
    self._output_layer_norm = []
-    activation_policy = tf.keras.mixed_precision.experimental.global_policy()
+    activation_policy = tf.keras.mixed_precision.global_policy()
    if activation_policy.name == "mixed_bfloat16":
      # bfloat16 causes BERT with the LAMB optimizer to not converge
      # as well, so we use float32.
......
@@ -29,7 +29,7 @@ class GatedFeedforwardTest(keras_parameterized.TestCase):
  def tearDown(self):
    super(GatedFeedforwardTest, self).tearDown()
-    tf.keras.mixed_precision.experimental.set_policy("float32")
+    tf.keras.mixed_precision.set_global_policy("float32")

  @parameterized.parameters(
      (True, 1, "after_residual", "float32"),
@@ -42,7 +42,7 @@ class GatedFeedforwardTest(keras_parameterized.TestCase):
      (False, 1, "before_residual", "mixed_float16"),
  )
  def test_layer_creation(self, use_gate, num_blocks, dropout_position, dtype):
-    tf.keras.mixed_precision.experimental.set_policy(dtype)
+    tf.keras.mixed_precision.set_global_policy(dtype)
    kwargs = dict(
        intermediate_size=128,
        intermediate_activation="relu",
@@ -74,7 +74,7 @@ class GatedFeedforwardTest(keras_parameterized.TestCase):
  )
  def test_layer_invocation(self, use_gate, num_blocks, dropout_position,
                            dtype):
-    tf.keras.mixed_precision.experimental.set_policy(dtype)
+    tf.keras.mixed_precision.set_global_policy(dtype)
    kwargs = dict(
        intermediate_size=16,
        intermediate_activation="relu",
......
@@ -132,7 +132,7 @@ class ReZeroTransformer(tf.keras.layers.Layer):
        bias_axes="d",
        name="intermediate",
        **common_kwargs)
-    policy = tf.keras.mixed_precision.experimental.global_policy()
+    policy = tf.keras.mixed_precision.global_policy()
    if policy.name == "mixed_bfloat16":
      # bfloat16 causes BERT with the LAMB optimizer to not converge
      # as well, so we use float32.
......
@@ -28,10 +28,10 @@ class TransformerWithReZeroLayerTest(keras_parameterized.TestCase):
  def tearDown(self):
    super(TransformerWithReZeroLayerTest, self).tearDown()
-    tf.keras.mixed_precision.experimental.set_policy('float32')
+    tf.keras.mixed_precision.set_global_policy('float32')

  def test_layer_invocation_with_float16_dtype(self):
-    tf.keras.mixed_precision.experimental.set_policy('mixed_float16')
+    tf.keras.mixed_precision.set_global_policy('mixed_float16')
    test_layer = rezero_transformer.ReZeroTransformer(
        num_attention_heads=10,
        intermediate_size=2048,
......
@@ -30,7 +30,7 @@ class TransformerLayerTest(keras_parameterized.TestCase):
  def tearDown(self):
    super(TransformerLayerTest, self).tearDown()
-    tf.keras.mixed_precision.experimental.set_policy('float32')
+    tf.keras.mixed_precision.set_global_policy('float32')

  def test_layer_creation(self, transformer_cls):
    test_layer = transformer_cls(
@@ -151,7 +151,7 @@ class TransformerLayerTest(keras_parameterized.TestCase):
        new_output_tensor, output_tensor[:, 0:1, :], atol=5e-5, rtol=0.003)

  def test_layer_invocation_with_float16_dtype(self, transformer_cls):
-    tf.keras.mixed_precision.experimental.set_policy('mixed_float16')
+    tf.keras.mixed_precision.set_global_policy('mixed_float16')
    test_layer = transformer_cls(
        num_attention_heads=16,
        intermediate_size=2048,
......
@@ -190,7 +190,7 @@ class TransformerScaffold(tf.keras.layers.Layer):
        bias_axes="d",
        name="intermediate",
        **common_kwargs)
-    policy = tf.keras.mixed_precision.experimental.global_policy()
+    policy = tf.keras.mixed_precision.global_policy()
    if policy.name == "mixed_bfloat16":
      # bfloat16 causes BERT with the LAMB optimizer to not converge
      # as well, so we use float32.
......
@@ -83,7 +83,7 @@ class TransformerLayerTest(keras_parameterized.TestCase):
  def tearDown(self):
    super(TransformerLayerTest, self).tearDown()
-    tf.keras.mixed_precision.experimental.set_policy('float32')
+    tf.keras.mixed_precision.set_global_policy('float32')

  def test_layer_creation(self):
    sequence_length = 21
@@ -308,7 +308,7 @@ class TransformerLayerTest(keras_parameterized.TestCase):
    self.assertTrue(call_list[0], "The passed layer class wasn't instantiated.")

  def test_layer_invocation_with_float16_dtype(self):
-    tf.keras.mixed_precision.experimental.set_policy('mixed_float16')
+    tf.keras.mixed_precision.set_global_policy('mixed_float16')
    sequence_length = 21
    width = 80
......
@@ -33,7 +33,7 @@ class AlbertEncoderTest(keras_parameterized.TestCase):
  def tearDown(self):
    super(AlbertEncoderTest, self).tearDown()
-    tf.keras.mixed_precision.experimental.set_policy("float32")
+    tf.keras.mixed_precision.set_global_policy("float32")

  @parameterized.named_parameters(
      dict(testcase_name="default", expected_dtype=tf.float32),
@@ -49,7 +49,7 @@ class AlbertEncoderTest(keras_parameterized.TestCase):
        num_attention_heads=2,
        num_layers=3)
    if expected_dtype == tf.float16:
-      tf.keras.mixed_precision.experimental.set_policy("mixed_float16")
+      tf.keras.mixed_precision.set_global_policy("mixed_float16")

    # Create a small TransformerEncoder for testing.
    test_network = albert_encoder.AlbertEncoder(**kwargs)
@@ -148,7 +148,7 @@ class AlbertEncoderTest(keras_parameterized.TestCase):
    self.assertLen(dict_outputs["pooled_output"], num_layers)

  def test_serialize_deserialize(self):
-    tf.keras.mixed_precision.experimental.set_policy("mixed_float16")
+    tf.keras.mixed_precision.set_global_policy("mixed_float16")
    # Create a network object that sets all of its config options.
    kwargs = dict(
        vocab_size=100,
......
@@ -30,7 +30,7 @@ class BertEncoderTest(keras_parameterized.TestCase):
  def tearDown(self):
    super(BertEncoderTest, self).tearDown()
-    tf.keras.mixed_precision.experimental.set_policy("float32")
+    tf.keras.mixed_precision.set_global_policy("float32")

  def test_network_creation(self):
    hidden_size = 32
@@ -119,7 +119,7 @@ class BertEncoderTest(keras_parameterized.TestCase):
  def test_network_creation_with_float16_dtype(self):
    hidden_size = 32
    sequence_length = 21
-    tf.keras.mixed_precision.experimental.set_policy("mixed_float16")
+    tf.keras.mixed_precision.set_global_policy("mixed_float16")
    # Create a small BertEncoder for testing.
    test_network = bert_encoder.BertEncoder(
        vocab_size=100,
......
@@ -59,7 +59,7 @@ class Classification(tf.keras.Model):
    if output == 'logits':
      output_tensors = logits
    elif output == 'predictions':
-      policy = tf.keras.mixed_precision.experimental.global_policy()
+      policy = tf.keras.mixed_precision.global_policy()
      if policy.name == 'mixed_bfloat16':
        # b/158514794: bf16 is not stable with post-softmax cross-entropy.
        policy = tf.float32
......
@@ -52,7 +52,7 @@ class EncoderScaffoldLayerClassTest(keras_parameterized.TestCase):
  def tearDown(self):
    super(EncoderScaffoldLayerClassTest, self).tearDown()
-    tf.keras.mixed_precision.experimental.set_policy("float32")
+    tf.keras.mixed_precision.set_global_policy("float32")

  @parameterized.named_parameters(
      dict(testcase_name="only_final_output", return_all_layer_outputs=False),
@@ -132,7 +132,7 @@ class EncoderScaffoldLayerClassTest(keras_parameterized.TestCase):
    self.assertTrue(hasattr(test_network, "_output_layer_norm"))

  def test_network_creation_with_float16_dtype(self):
-    tf.keras.mixed_precision.experimental.set_policy("mixed_float16")
+    tf.keras.mixed_precision.set_global_policy("mixed_float16")
    hidden_size = 32
    sequence_length = 21
    embedding_cfg = {
......
@@ -27,7 +27,7 @@ class PackedSequenceEmbeddingTest(tf.test.TestCase, parameterized.TestCase):
  def tearDown(self):
    super(PackedSequenceEmbeddingTest, self).tearDown()
-    tf.keras.mixed_precision.experimental.set_policy('float32')
+    tf.keras.mixed_precision.set_global_policy('float32')

  @parameterized.parameters([
      (True, True, True),
@@ -39,7 +39,7 @@ class PackedSequenceEmbeddingTest(tf.test.TestCase, parameterized.TestCase):
                           use_float16):
    """Validate that the Keras object can be created."""
    if use_float16:
-      tf.keras.mixed_precision.experimental.set_policy('mixed_float16')
+      tf.keras.mixed_precision.set_global_policy('mixed_float16')
    seq_length = 16
    vocab_size = 100
    max_position_embeddings = 32
@@ -99,7 +99,7 @@ class PackedSequenceEmbeddingTest(tf.test.TestCase, parameterized.TestCase):
    self.assertAllEqual(expected_attention_mask_shape, attention_mask.shape)

  def test_serialize_deserialize(self):
-    tf.keras.mixed_precision.experimental.set_policy('mixed_float16')
+    tf.keras.mixed_precision.set_global_policy('mixed_float16')
    # Create a network object that sets all of its config options.
    embedding_cfg = dict(
        vocab_size=100,
......
@@ -67,10 +67,10 @@ class TransformerTaskTest(tf.test.TestCase):
    self.bleu_source = os.path.join(temp_dir, 'bleu_source')
    self.bleu_ref = os.path.join(temp_dir, 'bleu_ref')
    self.orig_policy = (
-        tf.compat.v2.keras.mixed_precision.experimental.global_policy())
+        tf.compat.v2.keras.mixed_precision.global_policy())

  def tearDown(self):  # pylint: disable=g-missing-super-call
-    tf.compat.v2.keras.mixed_precision.experimental.set_policy(self.orig_policy)
+    tf.compat.v2.keras.mixed_precision.set_global_policy(self.orig_policy)

  def _assert_exists(self, filepath):
    self.assertTrue(os.path.exists(filepath))
......
@@ -70,9 +70,7 @@ def run_executor(params,
  """Runs the object detection model on distribution strategy defined by the user."""
  if params.architecture.use_bfloat16:
-    policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
-        'mixed_bfloat16')
-    tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)
+    tf.compat.v2.keras.mixed_precision.set_global_policy('mixed_bfloat16')

  model_builder = model_factory.model_generator(params)
......
@@ -60,9 +60,7 @@ class Model(object):
    self._use_bfloat16 = params.architecture.use_bfloat16
    if params.architecture.use_bfloat16:
-      policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
-          'mixed_bfloat16')
-      tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)
+      tf.compat.v2.keras.mixed_precision.set_global_policy('mixed_bfloat16')

    # Optimization.
    self._optimizer_fn = optimizers.OptimizerFactory(params.train.optimizer)
......