Commit d53aa621 authored by Nimit Nigania

Added non-XLA fp16 test

parent 36ef0b7a
@@ -23,7 +23,6 @@ import time
 from absl import flags
 from absl.testing import flagsaver
 import tensorflow as tf  # pylint: disable=g-bad-import-order
 from official.recommendation import ncf_common
@@ -262,6 +261,15 @@ class NCFKerasAccuracy(NCFKerasBenchmarkBase):
     FLAGS.train_epochs = 7
     self._run_and_report_benchmark_mlperf_like()
 
+  def benchmark_1_gpu_ctl_fp16_mlperf_like(self):
+    """1 GPU using CTL with fp16 mixed precision."""
+    self._setup()
+    FLAGS.keras_use_ctl = True
+    FLAGS.train_epochs = 7
+    FLAGS.dtype = 'fp16'
+    FLAGS.loss_scale = 8192
+    self._run_and_report_benchmark_mlperf_like()
+
   def benchmark_1_gpu_ctl_run_eagerly_mlperf_like(self):
     """1 GPU using CTL with eager and distribution strategy."""
     self._setup()
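For context on the `FLAGS.loss_scale = 8192` setting above: fp16 has a narrow exponent range, so small gradient values can underflow to zero during backprop. Static loss scaling multiplies the loss by a fixed factor before computing gradients and divides the gradients by the same factor before the update. A minimal sketch of the idea in a custom training loop; `scaled_train_step` and its arguments are hypothetical illustrations, not part of this commit:

```python
import tensorflow as tf

# Illustrative static loss scaling with the same factor the benchmark pins.
LOSS_SCALE = 8192.0

def scaled_train_step(model, optimizer, features, labels, loss_fn):
  with tf.GradientTape() as tape:
    predictions = model(features, training=True)
    loss = loss_fn(labels, predictions)
    # Scale the loss up before differentiation so small fp16 gradients
    # do not underflow to zero.
    scaled_loss = loss * LOSS_SCALE
  scaled_grads = tape.gradient(scaled_loss, model.trainable_variables)
  # Undo the scaling before the optimizer applies the update.
  grads = [g / LOSS_SCALE for g in scaled_grads]
  optimizer.apply_gradients(zip(grads, model.trainable_variables))
  return loss
```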
@@ -261,8 +261,11 @@ def run_ncf(_):
       beta_2=params["beta2"],
       epsilon=params["epsilon"])
   if FLAGS.dtype == "fp16":
-    optimizer = tf.compat.v1.train.experimental.enable_mixed_precision_graph_rewrite(
-        optimizer, loss_scale=flags_core.get_loss_scale(FLAGS, default_for_fp16="dynamic"))
+    optimizer = \
+      tf.compat.v1.train.experimental.enable_mixed_precision_graph_rewrite(
+          optimizer,
+          loss_scale=flags_core.get_loss_scale(FLAGS,
+                                               default_for_fp16="dynamic"))
 
   if params["keras_use_ctl"]:
     train_loss, eval_results = run_ncf_custom_training(
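As a usage note, `enable_mixed_precision_graph_rewrite` returns a loss-scaling wrapper around the given optimizer and enables a graph rewrite that runs compute-heavy ops such as matmuls and convolutions in float16. A hedged sketch of the call in isolation, assuming the `tf.compat.v1` path used in the diff; the learning rate here is a placeholder:

```python
import tensorflow as tf

# Sketch only: wrap a plain optimizer with the mixed precision graph rewrite,
# mirroring the change above. "dynamic" lets TensorFlow adjust the loss scale
# when gradients overflow; a float value (e.g. 8192) pins it instead.
optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=0.002)
optimizer = (
    tf.compat.v1.train.experimental.enable_mixed_precision_graph_rewrite(
        optimizer, loss_scale="dynamic"))
```

In the benchmark above, `FLAGS.loss_scale = 8192` overrides the `"dynamic"` default through `flags_core.get_loss_scale`.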