Commit a6f9945a authored by Reed Wanderman-Milne's avatar Reed Wanderman-Milne Committed by A. Unique TensorFlower
Browse files

Use tf.compat.v2 when accessing policies.

I soon plan on making mixed precision policies exposed in TF2 only.

PiperOrigin-RevId: 266433372
parent e170a8ba
......@@ -169,9 +169,9 @@ class TransformerTask(object):
# this.
loss_scale = flags_core.get_loss_scale(flags_obj,
default_for_fp16="dynamic")
policy = tf.keras.mixed_precision.experimental.Policy(
policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
"mixed_float16", loss_scale=loss_scale)
tf.keras.mixed_precision.experimental.set_policy(policy)
tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)
self.distribution_strategy = distribution_utils.get_distribution_strategy(
distribution_strategy=flags_obj.distribution_strategy,
......@@ -425,10 +425,11 @@ class TransformerTask(object):
opt, loss_scale=flags_core.get_loss_scale(self.flags_obj,
default_for_fp16="dynamic"))
if self.flags_obj.fp16_implementation == "graph_rewrite":
# Note: when flags_obj.fp16_implementation == "graph_rewrite",
# dtype as determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
# which will ensure tf.keras.mixed_precision and tf.train.experimental.enable_mixed_precision_graph_rewrite
# do not double up.
# Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
# determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
# which will ensure tf.compat.v2.keras.mixed_precision and
# tf.train.experimental.enable_mixed_precision_graph_rewrite do not double
# up.
opt = tf.train.experimental.enable_mixed_precision_graph_rewrite(opt)
return opt
......
......@@ -72,10 +72,11 @@ class TransformerTaskTest(tf.test.TestCase):
self.vocab_size = misc.get_model_params(FLAGS.param_set, 0)['vocab_size']
self.bleu_source = os.path.join(temp_dir, 'bleu_source')
self.bleu_ref = os.path.join(temp_dir, 'bleu_ref')
self.orig_policy = tf.keras.mixed_precision.experimental.global_policy()
self.orig_policy = (
tf.compat.v2.keras.mixed_precision.experimental.global_policy())
def tearDown(self):
tf.keras.mixed_precision.experimental.set_policy(self.orig_policy)
tf.compat.v2.keras.mixed_precision.experimental.set_policy(self.orig_policy)
def _assert_exists(self, filepath):
self.assertTrue(os.path.exists(filepath))
......
......@@ -96,9 +96,9 @@ def run(flags_obj):
dtype = flags_core.get_tf_dtype(flags_obj)
if dtype == 'float16':
loss_scale = flags_core.get_loss_scale(flags_obj, default_for_fp16=128)
policy = tf.keras.mixed_precision.experimental.Policy('mixed_float16',
loss_scale=loss_scale)
tf.keras.mixed_precision.experimental.set_policy(policy)
policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
'mixed_float16', loss_scale=loss_scale)
tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)
if not keras_utils.is_v2_0():
raise ValueError('--dtype=fp16 is not supported in TensorFlow 1.')
......@@ -183,11 +183,13 @@ def run(flags_obj):
with strategy_scope:
optimizer = common.get_optimizer(lr_schedule)
if flags_obj.fp16_implementation == "graph_rewrite":
# Note: when flags_obj.fp16_implementation == "graph_rewrite",
# dtype as determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
# which will ensure tf.keras.mixed_precision and tf.train.experimental.enable_mixed_precision_graph_rewrite
# do not double up.
optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(optimizer)
# Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
# determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
# which will ensure tf.compat.v2.keras.mixed_precision and
# tf.train.experimental.enable_mixed_precision_graph_rewrite do not double
# up.
optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
optimizer)
# TODO(hongkuny): Remove trivial model usage and move it to benchmark.
if flags_obj.use_trivial_model:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment