Commit 73e05832 authored by Chen Qian's avatar Chen Qian Committed by A. Unique TensorFlower
Browse files

Code changes to get ready for an incoming Keras optimizer migration.

PiperOrigin-RevId: 476479516
parent c0525d49
......@@ -319,7 +319,9 @@ class TrainerTest(tf.test.TestCase, parameterized.TestCase):
self.assertFalse(trainer.optimizer.dynamic)
self.assertEqual(trainer.optimizer.initial_scale, loss_scale)
else:
self.assertIsInstance(trainer.optimizer, tf.keras.optimizers.SGD)
self.assertIsInstance(
trainer.optimizer,
(tf.keras.optimizers.SGD, tf.keras.optimizers.legacy.SGD))
metrics = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
self.assertIn('training_loss', metrics)
......
......@@ -226,9 +226,13 @@ class TrainerWithMaskedLMTaskTest(tf.test.TestCase, parameterized.TestCase):
task = TestPolicy(None, config.task)
trainer = trainer_lib.ProgressiveTrainer(config, task, self.get_temp_dir())
if mixed_precision_dtype != 'float16':
self.assertIsInstance(trainer.optimizer, tf.keras.optimizers.SGD)
self.assertIsInstance(
trainer.optimizer,
(tf.keras.optimizers.SGD, tf.keras.optimizers.legacy.SGD))
elif mixed_precision_dtype == 'float16' and loss_scale is None:
self.assertIsInstance(trainer.optimizer, tf.keras.optimizers.SGD)
self.assertIsInstance(
trainer.optimizer,
(tf.keras.optimizers.SGD, tf.keras.optimizers.legacy.SGD))
metrics = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
self.assertIn('training_loss', metrics)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment