Commit 7bb5ab6d authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 320272300
parent c96e5d5b
...@@ -93,8 +93,11 @@ class Unet3DAccuracyBenchmark(keras_benchmark.KerasBenchmark): ...@@ -93,8 +93,11 @@ class Unet3DAccuracyBenchmark(keras_benchmark.KerasBenchmark):
"""Runs and reports the benchmark given the provided configuration.""" """Runs and reports the benchmark given the provided configuration."""
params = unet_training_lib.extract_params(FLAGS) params = unet_training_lib.extract_params(FLAGS)
strategy = unet_training_lib.create_distribution_strategy(params) strategy = unet_training_lib.create_distribution_strategy(params)
if params.use_bfloat16:
policy = tf.keras.mixed_precision.experimental.Policy('mixed_bfloat16') input_dtype = params.dtype
if input_dtype == 'float16' or input_dtype == 'bfloat16':
policy = tf.keras.mixed_precision.experimental.Policy(
'mixed_bfloat16' if input_dtype == 'bfloat16' else 'mixed_float16')
tf.keras.mixed_precision.experimental.set_policy(policy) tf.keras.mixed_precision.experimental.set_policy(policy)
stats = {} stats = {}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment