Commit f006521b authored by Le Hou's avatar Le Hou Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 372242256
parent e7c57743
......@@ -370,7 +370,13 @@ class Trainer(_AsyncTrainer):
logs[metric.name] = metric.result()
metric.reset_states()
if callable(self.optimizer.learning_rate):
logs["learning_rate"] = self.optimizer.learning_rate(self.global_step)
# Maybe a self-implemented optimizer does not have `optimizer.iterations`.
# So just to be safe here.
if hasattr(self.optimizer, "iterations"):
logs["learning_rate"] = self.optimizer.learning_rate(
self.optimizer.iterations)
else:
logs["learning_rate"] = self.optimizer.learning_rate(self.global_step)
else:
logs["learning_rate"] = self.optimizer.learning_rate
return logs
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment