"tests/vscode:/vscode.git/clone" did not exist on "9a9d53fba3c610c793d75373dff1983de85e638d"
Commit 73e05832 authored by Chen Qian's avatar Chen Qian Committed by A. Unique TensorFlower
Browse files

Code changes to get ready for an incoming Keras optimizer migration.

PiperOrigin-RevId: 476479516
parent c0525d49
...@@ -319,7 +319,9 @@ class TrainerTest(tf.test.TestCase, parameterized.TestCase): ...@@ -319,7 +319,9 @@ class TrainerTest(tf.test.TestCase, parameterized.TestCase):
self.assertFalse(trainer.optimizer.dynamic) self.assertFalse(trainer.optimizer.dynamic)
self.assertEqual(trainer.optimizer.initial_scale, loss_scale) self.assertEqual(trainer.optimizer.initial_scale, loss_scale)
else: else:
self.assertIsInstance(trainer.optimizer, tf.keras.optimizers.SGD) self.assertIsInstance(
trainer.optimizer,
(tf.keras.optimizers.SGD, tf.keras.optimizers.legacy.SGD))
metrics = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32)) metrics = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
self.assertIn('training_loss', metrics) self.assertIn('training_loss', metrics)
......
...@@ -226,9 +226,13 @@ class TrainerWithMaskedLMTaskTest(tf.test.TestCase, parameterized.TestCase): ...@@ -226,9 +226,13 @@ class TrainerWithMaskedLMTaskTest(tf.test.TestCase, parameterized.TestCase):
task = TestPolicy(None, config.task) task = TestPolicy(None, config.task)
trainer = trainer_lib.ProgressiveTrainer(config, task, self.get_temp_dir()) trainer = trainer_lib.ProgressiveTrainer(config, task, self.get_temp_dir())
if mixed_precision_dtype != 'float16': if mixed_precision_dtype != 'float16':
self.assertIsInstance(trainer.optimizer, tf.keras.optimizers.SGD) self.assertIsInstance(
trainer.optimizer,
(tf.keras.optimizers.SGD, tf.keras.optimizers.legacy.SGD))
elif mixed_precision_dtype == 'float16' and loss_scale is None: elif mixed_precision_dtype == 'float16' and loss_scale is None:
self.assertIsInstance(trainer.optimizer, tf.keras.optimizers.SGD) self.assertIsInstance(
trainer.optimizer,
(tf.keras.optimizers.SGD, tf.keras.optimizers.legacy.SGD))
metrics = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32)) metrics = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
self.assertIn('training_loss', metrics) self.assertIn('training_loss', metrics)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment