Commit a6f9945a authored by Reed Wanderman-Milne, committed by A. Unique TensorFlower

Use tf.compat.v2 when accessing policies.

I soon plan to expose mixed precision policies in TF2 only.

PiperOrigin-RevId: 266433372
parent e170a8ba
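
For context, the pattern this change adopts looks roughly like the following. This is a minimal sketch: the flag plumbing in the affected files (flags_core.get_loss_scale and flags_obj) is omitted, and a literal "dynamic" loss scale stands in for the flag-derived value.

import tensorflow as tf

# Minimal sketch of the pattern this commit adopts. Accessing the mixed
# precision symbols through tf.compat.v2 keeps these call sites working once
# the policies are exposed only under the TF2 API.
policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
    "mixed_float16", loss_scale="dynamic")  # "dynamic" stands in for the flag-derived loss scale
tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)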
@@ -169,9 +169,9 @@ class TransformerTask(object):
       # this.
       loss_scale = flags_core.get_loss_scale(flags_obj,
                                              default_for_fp16="dynamic")
-      policy = tf.keras.mixed_precision.experimental.Policy(
+      policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
           "mixed_float16", loss_scale=loss_scale)
-      tf.keras.mixed_precision.experimental.set_policy(policy)
+      tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)
 
     self.distribution_strategy = distribution_utils.get_distribution_strategy(
         distribution_strategy=flags_obj.distribution_strategy,
@@ -425,10 +425,11 @@ class TransformerTask(object):
           opt, loss_scale=flags_core.get_loss_scale(self.flags_obj,
                                                     default_for_fp16="dynamic"))
     if self.flags_obj.fp16_implementation == "graph_rewrite":
-      # Note: when flags_obj.fp16_implementation == "graph_rewrite",
-      # dtype as determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
-      # which will ensure tf.keras.mixed_precision and tf.train.experimental.enable_mixed_precision_graph_rewrite
-      # do not double up.
+      # Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
+      # determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
+      # which will ensure tf.compat.v2.keras.mixed_precision and
+      # tf.train.experimental.enable_mixed_precision_graph_rewrite do not double
+      # up.
       opt = tf.train.experimental.enable_mixed_precision_graph_rewrite(opt)
 
     return opt
...
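
The note reworded in this hunk explains why the two fp16 mechanisms never both apply: with --fp16_implementation=graph_rewrite the dtype flag resolves to 'float32', so the Keras policy / loss-scale path is skipped and only the graph rewrite inserts float16 compute. A rough standalone sketch of that branching follows; the Keras-path wrapper shown here is LossScaleOptimizer, which is an assumption since the wrapping call is truncated in the hunk above.

import tensorflow as tf

def wrap_optimizer(opt, dtype, fp16_implementation, loss_scale):
  # Sketch of the two mutually exclusive fp16 paths described in the note.
  if dtype == "float16":
    # Keras mixed precision path: only reached when the dtype flag resolves
    # to float16 (i.e. the Keras fp16 implementation is selected).
    opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
        opt, loss_scale=loss_scale)
  if fp16_implementation == "graph_rewrite":
    # Graph-rewrite path: dtype resolves to float32 in this case, so the
    # branch above is skipped and the two mechanisms do not double up.
    opt = tf.train.experimental.enable_mixed_precision_graph_rewrite(opt)
  return opt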
@@ -72,10 +72,11 @@ class TransformerTaskTest(tf.test.TestCase):
     self.vocab_size = misc.get_model_params(FLAGS.param_set, 0)['vocab_size']
     self.bleu_source = os.path.join(temp_dir, 'bleu_source')
     self.bleu_ref = os.path.join(temp_dir, 'bleu_ref')
-    self.orig_policy = tf.keras.mixed_precision.experimental.global_policy()
+    self.orig_policy = (
+        tf.compat.v2.keras.mixed_precision.experimental.global_policy())
 
   def tearDown(self):
-    tf.keras.mixed_precision.experimental.set_policy(self.orig_policy)
+    tf.compat.v2.keras.mixed_precision.experimental.set_policy(self.orig_policy)
 
   def _assert_exists(self, filepath):
     self.assertTrue(os.path.exists(filepath))
...
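
The test change above follows the usual pattern of snapshotting the global policy in setUp and restoring it in tearDown, so a test that installs mixed_float16 cannot leak the policy into later tests. A standalone sketch of the same pattern; the class and test names here are hypothetical and only illustrate the idea.

import tensorflow as tf

class PolicyIsolationTest(tf.test.TestCase):
  # Hypothetical test class illustrating the save/restore pattern above.

  def setUp(self):
    super(PolicyIsolationTest, self).setUp()
    self.orig_policy = (
        tf.compat.v2.keras.mixed_precision.experimental.global_policy())

  def tearDown(self):
    # Restore whatever policy was active before the test ran.
    tf.compat.v2.keras.mixed_precision.experimental.set_policy(self.orig_policy)
    super(PolicyIsolationTest, self).tearDown()

  def test_mixed_float16_does_not_leak(self):
    policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
        'mixed_float16')
    tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)
    # ... exercise model code under the mixed precision policy here ...

if __name__ == '__main__':
  tf.test.main()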
...@@ -96,9 +96,9 @@ def run(flags_obj): ...@@ -96,9 +96,9 @@ def run(flags_obj):
dtype = flags_core.get_tf_dtype(flags_obj) dtype = flags_core.get_tf_dtype(flags_obj)
if dtype == 'float16': if dtype == 'float16':
loss_scale = flags_core.get_loss_scale(flags_obj, default_for_fp16=128) loss_scale = flags_core.get_loss_scale(flags_obj, default_for_fp16=128)
policy = tf.keras.mixed_precision.experimental.Policy('mixed_float16', policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
loss_scale=loss_scale) 'mixed_float16', loss_scale=loss_scale)
tf.keras.mixed_precision.experimental.set_policy(policy) tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)
if not keras_utils.is_v2_0(): if not keras_utils.is_v2_0():
raise ValueError('--dtype=fp16 is not supported in TensorFlow 1.') raise ValueError('--dtype=fp16 is not supported in TensorFlow 1.')
@@ -183,11 +183,13 @@ def run(flags_obj):
   with strategy_scope:
     optimizer = common.get_optimizer(lr_schedule)
     if flags_obj.fp16_implementation == "graph_rewrite":
-      # Note: when flags_obj.fp16_implementation == "graph_rewrite",
-      # dtype as determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
-      # which will ensure tf.keras.mixed_precision and tf.train.experimental.enable_mixed_precision_graph_rewrite
-      # do not double up.
-      optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(optimizer)
+      # Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
+      # determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
+      # which will ensure tf.compat.v2.keras.mixed_precision and
+      # tf.train.experimental.enable_mixed_precision_graph_rewrite do not double
+      # up.
+      optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
+          optimizer)
 
     # TODO(hongkuny): Remove trivial model usage and move it to benchmark.
     if flags_obj.use_trivial_model:
...