Commit 5d340ff3 authored by Chen Qian, committed by A. Unique TensorFlower

Optimizer changes for compatibility with the incoming Keras optimizer migration.

PiperOrigin-RevId: 468557315
parent b81fe53a
@@ -683,8 +683,14 @@ class DistributedExecutor(object):
     if not checkpoint_path:
       raise ValueError('checkpoint path is empty')
     reader = tf.compat.v1.train.NewCheckpointReader(checkpoint_path)
-    current_step = reader.get_tensor(
-        'optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE')
+    if reader.has_tensor('optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE'):
+      # Legacy keras optimizer iteration.
+      current_step = reader.get_tensor(
+          'optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE')
+    else:
+      # New keras optimizer iteration.
+      current_step = reader.get_tensor(
+          'optimizer/_iterations/.ATTRIBUTES/VARIABLE_VALUE')
     logging.info('Checkpoint file %s found and restoring from '
                  'checkpoint', checkpoint_path)
     status = checkpoint.restore(checkpoint_path)
...
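For context, a minimal sketch of the key lookup this hunk introduces: the legacy Keras optimizer tracks its step under `optimizer/iter`, while the new optimizer stores it under `optimizer/_iterations`. The helper name `read_checkpoint_step` below is illustrative, not part of the change.

```python
import tensorflow as tf


def read_checkpoint_step(checkpoint_path):
  """Returns the saved training step, whichever optimizer wrote the checkpoint."""
  reader = tf.compat.v1.train.NewCheckpointReader(checkpoint_path)
  legacy_key = 'optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE'      # legacy tf.keras optimizer
  new_key = 'optimizer/_iterations/.ATTRIBUTES/VARIABLE_VALUE'  # new Keras optimizer
  step_key = legacy_key if reader.has_tensor(legacy_key) else new_key
  return reader.get_tensor(step_key)
```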
@@ -92,7 +92,7 @@ def create_optimizer(init_lr,
         include_in_weight_decay=["r_s_bias", "r_r_bias", "r_w_bias"])
   else:
     logging.info("Using Adam with adam_epsilon=%.9f", (adam_epsilon))
-    optimizer = tf.keras.optimizers.Adam(
+    optimizer = tf.keras.optimizers.legacy.Adam(
         learning_rate=learning_rate_fn, epsilon=adam_epsilon)
   return optimizer, learning_rate_fn
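Pinning `tf.keras.optimizers.legacy.Adam` keeps the current V2 optimizer behavior once `tf.keras.optimizers.Adam` starts resolving to the new implementation. A minimal sketch, assuming a TF release that already ships the `legacy` namespace; the schedule and hyperparameter values are illustrative only:

```python
import tensorflow as tf

# Explicitly pin the legacy Adam so epsilon and slot-variable semantics stay
# unchanged after the Keras optimizer migration. Values are illustrative.
learning_rate_fn = tf.keras.optimizers.schedules.PolynomialDecay(
    initial_learning_rate=2e-5, decay_steps=10000, end_learning_rate=0.0)
optimizer = tf.keras.optimizers.legacy.Adam(
    learning_rate=learning_rate_fn, epsilon=1e-6)
```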
@@ -69,6 +69,9 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
     zero_grads = lambda gv: [(tf.zeros_like(g), v) for g, v in gv]
     optimizer = opt_factory.build_optimizer(lr, gradient_aggregator=zero_grads)
+    if isinstance(optimizer, tf.keras.optimizers.experimental.Optimizer):
+      self.skipTest('New Keras optimizer does not support '
+                    '`gradient_aggregator` arg.')
     var0 = tf.Variable([1.0, 2.0])
     var1 = tf.Variable([3.0, 4.0])
...
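The test now skips when the factory returns a new-style optimizer, because only the legacy `OptimizerV2` constructors accept `gradient_aggregator`. A sketch of the same guard outside a test, assuming a TF version where both the `legacy` and `experimental` namespaces exist:

```python
import tensorflow as tf

# `gradient_aggregator` rewrites the (gradient, variable) pairs before they are
# applied; here it simply zeroes every gradient, mirroring the test above.
zero_grads = lambda gv: [(tf.zeros_like(g), v) for g, v in gv]

# Only the legacy optimizer constructor accepts this argument.
optimizer = tf.keras.optimizers.legacy.SGD(
    learning_rate=0.1, gradient_aggregator=zero_grads)
assert not isinstance(optimizer, tf.keras.optimizers.experimental.Optimizer)
```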
@@ -83,7 +83,7 @@ class SentencePredictionTaskTest(tf.test.TestCase, parameterized.TestCase):
         functools.partial(task.build_inputs, config.train_data))
     iterator = iter(dataset)
-    optimizer = tf.keras.optimizers.SGD(lr=0.1)
+    optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
     task.train_step(next(iterator), model, optimizer, metrics=metrics)
     model.save(os.path.join(self.get_temp_dir(), "saved_model"))
     return task.validation_step(next(iterator), model, metrics=metrics)
@@ -120,7 +120,7 @@ class SentencePredictionTaskTest(tf.test.TestCase, parameterized.TestCase):
     dataset = task.build_inputs(config.train_data)
     iterator = iter(dataset)
-    optimizer = tf.keras.optimizers.SGD(lr=0.1)
+    optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
     task.initialize(model)
     task.train_step(next(iterator), model, optimizer, metrics=metrics)
     task.validation_step(next(iterator), model, metrics=metrics)
@@ -151,7 +151,7 @@ class SentencePredictionTaskTest(tf.test.TestCase, parameterized.TestCase):
     dataset = task.build_inputs(config.train_data)
     iterator = iter(dataset)
-    optimizer = tf.keras.optimizers.SGD(lr=0.1)
+    optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
     task.train_step(next(iterator), model, optimizer, metrics=metrics)
     logs = task.validation_step(next(iterator), model, metrics=metrics)
...
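The `lr` constructor alias is deprecated in the legacy optimizers and not accepted by the new Keras optimizers, so these tests switch to the canonical keyword. A minimal sketch of the rename:

```python
import tensorflow as tf

# Old, deprecated spelling (rejected by the new Keras optimizer):
#   optimizer = tf.keras.optimizers.SGD(lr=0.1)
# Canonical spelling that works with both optimizer generations:
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
```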
@@ -120,9 +120,13 @@ class Trainer(tf.keras.Model):
     tvars = self.trainable_variables
     grads = tape.gradient(scaled_loss, tvars)
     self.optimizer.apply_gradients(list(zip(grads, tvars)))
+    if isinstance(self.optimizer, tf.keras.optimizers.experimental.Optimizer):
+      learning_rate = self.optimizer.learning_rate
+    else:
+      learning_rate = self.optimizer._decayed_lr(var_dtype=tf.float32)
     return {
         "training_loss": loss,
-        "learning_rate": self.optimizer._decayed_lr(var_dtype=tf.float32)
+        "learning_rate": learning_rate,
     }
...
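A sketch of the reporting logic this hunk adds, pulled out as a standalone helper (the function name is illustrative): the new optimizer exposes its current rate through the public `learning_rate` attribute, while the legacy one only offers the private `_decayed_lr` helper to resolve a schedule.

```python
import tensorflow as tf


def current_learning_rate(optimizer):
  # New (experimental) Keras optimizers expose the rate directly; legacy
  # optimizers need the private _decayed_lr() to resolve a schedule value.
  if isinstance(optimizer, tf.keras.optimizers.experimental.Optimizer):
    return optimizer.learning_rate
  return optimizer._decayed_lr(var_dtype=tf.float32)
```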