Commit a4235e26 authored by Chen Qian's avatar Chen Qian Committed by A. Unique TensorFlower
Browse files

Optimizer change to get compatible for an incoming Keras optimizer migration.

PiperOrigin-RevId: 477217857
parent 10673875
...@@ -238,6 +238,9 @@ class BertDistillationTask(policies.ProgressivePolicy, base_task.Task): ...@@ -238,6 +238,9 @@ class BertDistillationTask(policies.ProgressivePolicy, base_task.Task):
}) })
opt_factory = optimization.OptimizerFactory(params) opt_factory = optimization.OptimizerFactory(params)
optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate()) optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())
if isinstance(optimizer, tf.keras.optimizers.experimental.Optimizer):
optimizer = tf.keras.__internal__.optimizers.convert_to_legacy_optimizer(
optimizer)
return optimizer return optimizer
......
...@@ -153,7 +153,7 @@ class DistillationTest(tf.test.TestCase, parameterized.TestCase): ...@@ -153,7 +153,7 @@ class DistillationTest(tf.test.TestCase, parameterized.TestCase):
eval_dataset = bert_distillation_task.get_eval_dataset(stage_id=0) eval_dataset = bert_distillation_task.get_eval_dataset(stage_id=0)
eval_iterator = iter(eval_dataset) eval_iterator = iter(eval_dataset)
optimizer = tf.keras.optimizers.SGD(lr=0.1) optimizer = tf.keras.optimizers.legacy.SGD(learning_rate=0.1)
# test train/val step for all stages, including the last pretraining stage # test train/val step for all stages, including the last pretraining stage
for stage in range(student_block_num + 1): for stage in range(student_block_num + 1):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment