Commit 1b5a4c9e authored by Reed Wanderman-Milne's avatar Reed Wanderman-Milne Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 343529801
parent 3d081c09
......@@ -228,7 +228,7 @@ def initialize(params: base_configs.ExperimentConfig,
"""Initializes backend related initializations."""
keras_utils.set_session_config(enable_xla=params.runtime.enable_xla)
performance.set_mixed_precision_policy(dataset_builder.dtype,
get_loss_scale(params))
use_experimental_api=False)
if tf.config.list_physical_devices('GPU'):
data_format = 'channels_first'
else:
......@@ -338,6 +338,10 @@ def train_and_eval(
base_learning_rate=learning_rate,
params=params.model.optimizer.as_dict(),
model=model)
optimizer = performance.configure_optimizer(
optimizer,
use_float16=train_builder.dtype == 'float16',
loss_scale=get_loss_scale(params))
metrics_map = _get_metrics(one_hot)
metrics = [metrics_map[metric] for metric in params.train.metrics]
......
......@@ -121,8 +121,7 @@ class UtilTests(parameterized.TestCase, tf.test.TestCase):
def test_resume_from_checkpoint(self):
"""Tests functionality for resuming from checkpoint."""
# Set the keras policy
policy = tf.keras.mixed_precision.experimental.Policy('mixed_bfloat16')
tf.keras.mixed_precision.experimental.set_policy(policy)
tf.keras.mixed_precision.set_global_policy('mixed_bfloat16')
# Get the model, datasets, and compile it.
model = get_trivial_model(10)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment