"examples/vscode:/vscode.git/clone" did not exist on "97baba1b2cf376198fbd383bdceec53dbc932c01"
Commit 68deb504 authored by Chen Qian, committed by A. Unique TensorFlower

Optimizer changes for compatibility with an incoming Keras optimizer migration.

PiperOrigin-RevId: 468557315
parent ab9cb561
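For context, the migration swaps the default `tf.keras.optimizers.*` classes for the new (experimental) optimizer implementation, so code that depends on legacy behavior pins the old classes explicitly and branches where the two APIs differ. A minimal sketch, not part of this commit, of telling the two flavors apart at runtime, assuming a TF release that ships both the `legacy` and `experimental` optimizer namespaces:

import tensorflow as tf

# Construct one optimizer of each flavor and check which base class it uses.
legacy_opt = tf.keras.optimizers.legacy.Adam(learning_rate=1e-3)
new_opt = tf.keras.optimizers.experimental.Adam(learning_rate=1e-3)

for opt in (legacy_opt, new_opt):
  is_new = isinstance(opt, tf.keras.optimizers.experimental.Optimizer)
  print(type(opt).__name__, "new" if is_new else "legacy")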
@@ -683,8 +683,14 @@ class DistributedExecutor(object):
     if not checkpoint_path:
       raise ValueError('checkpoint path is empty')
     reader = tf.compat.v1.train.NewCheckpointReader(checkpoint_path)
-    current_step = reader.get_tensor(
-        'optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE')
+    if reader.has_tensor('optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE'):
+      # Legacy keras optimizer iteration.
+      current_step = reader.get_tensor(
+          'optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE')
+    else:
+      # New keras optimizer iteration.
+      current_step = reader.get_tensor(
+          'optimizer/_iterations/.ATTRIBUTES/VARIABLE_VALUE')
     logging.info('Checkpoint file %s found and restoring from '
                  'checkpoint', checkpoint_path)
     status = checkpoint.restore(checkpoint_path)
......
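The two optimizer flavors serialize their training-step counter under different checkpoint keys (`optimizer/iter/...` for the legacy classes, `optimizer/_iterations/...` for the new ones, per the branch above). A hedged sketch of inspecting which key a given checkpoint actually contains; the checkpoint prefix here is hypothetical:

import tensorflow as tf

checkpoint_path = '/tmp/model_dir/ckpt-100'  # hypothetical checkpoint prefix
reader = tf.compat.v1.train.NewCheckpointReader(checkpoint_path)
# Print every serialized value whose key looks like an iteration counter.
for key in reader.get_variable_to_shape_map():
  if 'ATTRIBUTES/VARIABLE_VALUE' in key and 'iter' in key:
    print(key, reader.get_tensor(key))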
@@ -92,7 +92,7 @@ def create_optimizer(init_lr,
         include_in_weight_decay=["r_s_bias", "r_r_bias", "r_w_bias"])
   else:
     logging.info("Using Adam with adam_epsilon=%.9f", (adam_epsilon))
-    optimizer = tf.keras.optimizers.Adam(
+    optimizer = tf.keras.optimizers.legacy.Adam(
         learning_rate=learning_rate_fn, epsilon=adam_epsilon)
   return optimizer, learning_rate_fn
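The hunk above pins the legacy Adam explicitly so this path keeps its pre-migration optimizer. A sketch of the same construction in isolation; the schedule and epsilon below are placeholders, not the values `create_optimizer` actually receives:

import tensorflow as tf

# Placeholder schedule and epsilon, for illustration only.
learning_rate_fn = tf.keras.optimizers.schedules.PolynomialDecay(
    initial_learning_rate=1e-4, decay_steps=10000, end_learning_rate=0.0)
optimizer = tf.keras.optimizers.legacy.Adam(
    learning_rate=learning_rate_fn, epsilon=1e-6)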
@@ -69,6 +69,9 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
     zero_grads = lambda gv: [(tf.zeros_like(g), v) for g, v in gv]
     optimizer = opt_factory.build_optimizer(lr, gradient_aggregator=zero_grads)
+    if isinstance(optimizer, tf.keras.optimizers.experimental.Optimizer):
+      self.skipTest('New Keras optimizer does not support '
+                    '`gradient_aggregator` arg.')
     var0 = tf.Variable([1.0, 2.0])
     var1 = tf.Variable([3.0, 4.0])
......
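The `gradient_aggregator` constructor argument only exists on the legacy optimizer base class, which is why the test above skips itself when the factory hands back a new-style optimizer. A minimal sketch of the legacy behavior, constructing `tf.keras.optimizers.legacy.SGD` directly rather than going through `opt_factory.build_optimizer` as the test does:

import tensorflow as tf

# The aggregator rewrites the (grad, var) pairs before they are applied;
# zeroing every gradient leaves the variable untouched.
zero_grads = lambda gv: [(tf.zeros_like(g), v) for g, v in gv]
optimizer = tf.keras.optimizers.legacy.SGD(
    learning_rate=0.1, gradient_aggregator=zero_grads)

var = tf.Variable([1.0, 2.0])
optimizer.apply_gradients([(tf.constant([1.0, 1.0]), var)])
print(var.numpy())  # still [1.0, 2.0]: the aggregator zeroed the gradient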
@@ -83,7 +83,7 @@ class SentencePredictionTaskTest(tf.test.TestCase, parameterized.TestCase):
         functools.partial(task.build_inputs, config.train_data))
     iterator = iter(dataset)
-    optimizer = tf.keras.optimizers.SGD(lr=0.1)
+    optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
     task.train_step(next(iterator), model, optimizer, metrics=metrics)
     model.save(os.path.join(self.get_temp_dir(), "saved_model"))
     return task.validation_step(next(iterator), model, metrics=metrics)
@@ -120,7 +120,7 @@ class SentencePredictionTaskTest(tf.test.TestCase, parameterized.TestCase):
     dataset = task.build_inputs(config.train_data)
     iterator = iter(dataset)
-    optimizer = tf.keras.optimizers.SGD(lr=0.1)
+    optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
     task.initialize(model)
     task.train_step(next(iterator), model, optimizer, metrics=metrics)
     task.validation_step(next(iterator), model, metrics=metrics)
@@ -151,7 +151,7 @@ class SentencePredictionTaskTest(tf.test.TestCase, parameterized.TestCase):
     dataset = task.build_inputs(config.train_data)
     iterator = iter(dataset)
-    optimizer = tf.keras.optimizers.SGD(lr=0.1)
+    optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
     task.train_step(next(iterator), model, optimizer, metrics=metrics)
     logs = task.validation_step(next(iterator), model, metrics=metrics)
......
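The `lr` keyword is a deprecated alias on the legacy optimizers and is not guaranteed to be accepted by the new classes, so the three test hunks above switch to the canonical `learning_rate` keyword, which works against either implementation. A tiny sanity check, offered only as an illustration:

import tensorflow as tf

optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
var = tf.Variable([1.0, 2.0])
optimizer.apply_gradients([(tf.constant([1.0, 1.0]), var)])
print(var.numpy())  # [0.9, 1.9] after one step at learning_rate=0.1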
@@ -120,9 +120,13 @@ class Trainer(tf.keras.Model):
     tvars = self.trainable_variables
     grads = tape.gradient(scaled_loss, tvars)
     self.optimizer.apply_gradients(list(zip(grads, tvars)))
+    if isinstance(self.optimizer, tf.keras.optimizers.experimental.Optimizer):
+      learning_rate = self.optimizer.learning_rate
+    else:
+      learning_rate = self.optimizer._decayed_lr(var_dtype=tf.float32)
     return {
         "training_loss": loss,
-        "learning_rate": self.optimizer._decayed_lr(var_dtype=tf.float32)
+        "learning_rate": learning_rate,
     }
......
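The legacy optimizer only exposes the schedule-resolved learning rate through its private `_decayed_lr` helper, while the new optimizer exposes a `learning_rate` attribute directly, so the hunk above branches on the class before logging. The same check as a standalone helper, a sketch rather than code from this commit:

import tensorflow as tf

def current_learning_rate(optimizer):
  # New-style optimizers expose `learning_rate` directly; the legacy classes
  # resolve schedules through the private `_decayed_lr` helper.
  if isinstance(optimizer, tf.keras.optimizers.experimental.Optimizer):
    return optimizer.learning_rate
  return optimizer._decayed_lr(var_dtype=tf.float32)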