Commit 50efd367 authored by Reed Wanderman-Milne's avatar Reed Wanderman-Milne Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 343529801
parent cc370af9
...@@ -228,7 +228,7 @@ def initialize(params: base_configs.ExperimentConfig, ...@@ -228,7 +228,7 @@ def initialize(params: base_configs.ExperimentConfig,
"""Initializes backend related initializations.""" """Initializes backend related initializations."""
keras_utils.set_session_config(enable_xla=params.runtime.enable_xla) keras_utils.set_session_config(enable_xla=params.runtime.enable_xla)
performance.set_mixed_precision_policy(dataset_builder.dtype, performance.set_mixed_precision_policy(dataset_builder.dtype,
get_loss_scale(params)) use_experimental_api=False)
if tf.config.list_physical_devices('GPU'): if tf.config.list_physical_devices('GPU'):
data_format = 'channels_first' data_format = 'channels_first'
else: else:
...@@ -338,6 +338,10 @@ def train_and_eval( ...@@ -338,6 +338,10 @@ def train_and_eval(
base_learning_rate=learning_rate, base_learning_rate=learning_rate,
params=params.model.optimizer.as_dict(), params=params.model.optimizer.as_dict(),
model=model) model=model)
optimizer = performance.configure_optimizer(
optimizer,
use_float16=train_builder.dtype == 'float16',
loss_scale=get_loss_scale(params))
metrics_map = _get_metrics(one_hot) metrics_map = _get_metrics(one_hot)
metrics = [metrics_map[metric] for metric in params.train.metrics] metrics = [metrics_map[metric] for metric in params.train.metrics]
......
...@@ -121,8 +121,7 @@ class UtilTests(parameterized.TestCase, tf.test.TestCase): ...@@ -121,8 +121,7 @@ class UtilTests(parameterized.TestCase, tf.test.TestCase):
def test_resume_from_checkpoint(self): def test_resume_from_checkpoint(self):
"""Tests functionality for resuming from checkpoint.""" """Tests functionality for resuming from checkpoint."""
# Set the keras policy # Set the keras policy
policy = tf.keras.mixed_precision.experimental.Policy('mixed_bfloat16') tf.keras.mixed_precision.set_global_policy('mixed_bfloat16')
tf.keras.mixed_precision.experimental.set_policy(policy)
# Get the model, datasets, and compile it. # Get the model, datasets, and compile it.
model = get_trivial_model(10) model = get_trivial_model(10)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment