"vscode:/vscode.git/clone" did not exist on "02b176c4ce14340d26d42825523f406959c6c202"
Unverified commit 349a6e85, authored by Matt, committed by GitHub

Fix Keras scheduler import so it works for older versions of Keras (#28895)

Fix our schedule import so it works for older versions of Keras
parent d9deddb4
@@ -29,7 +29,14 @@ except (ImportError, ModuleNotFoundError):
     from .modeling_tf_utils import keras


-class WarmUp(keras.optimizers.schedules.LearningRateSchedule):
+# This block because Keras loves randomly moving things to different places - this changed somewhere between 2.10 - 2.15
+if hasattr(keras.optimizers.schedules, "learning_rate_schedule"):
+    schedules = keras.optimizers.schedules.learning_rate_schedule
+else:
+    schedules = keras.optimizers.schedules
+
+
+class WarmUp(schedules.LearningRateSchedule):
     """
     Applies a warmup schedule on a given learning rate decay schedule.
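
The fix resolves the schedule module once at import time and binds it to a `schedules` alias, so the rest of the file works whether or not the installed Keras exposes the `learning_rate_schedule` submodule. A minimal standalone sketch of the same pattern, assuming a plain `import keras` in place of the library's internal `from .modeling_tf_utils import keras`:

    import keras

    # Resolve the schedule classes once; the attribute check covers both the older
    # and the newer module layout (the location changed somewhere between Keras
    # 2.10 and 2.15, per the comment in the diff above).
    if hasattr(keras.optimizers.schedules, "learning_rate_schedule"):
        schedules = keras.optimizers.schedules.learning_rate_schedule
    else:
        schedules = keras.optimizers.schedules

    # Later definitions reference the alias instead of spelling out the full path.
    class WarmUp(schedules.LearningRateSchedule):
        ...
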
@@ -133,7 +140,7 @@ def create_optimizer(
         applied to all parameters except bias and layer norm parameters.
     """
     # Implements linear decay of the learning rate.
-    lr_schedule = keras.optimizers.schedules.PolynomialDecay(
+    lr_schedule = schedules.PolynomialDecay(
         initial_learning_rate=init_lr,
         decay_steps=num_train_steps - num_warmup_steps,
         end_learning_rate=init_lr * min_lr_ratio,
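
For reference, `PolynomialDecay` with its default `power=1.0` is plain linear decay, which is what the comment above means. A hedged sketch of building the same kind of schedule directly, with illustrative numbers standing in for `create_optimizer`'s `init_lr`, `num_train_steps`, `num_warmup_steps`, and `min_lr_ratio` arguments:

    import keras

    # On a recent Keras this path resolves directly; through the commit's alias it
    # would be written `schedules.PolynomialDecay(...)`.
    lr_schedule = keras.optimizers.schedules.PolynomialDecay(
        initial_learning_rate=5e-5,      # init_lr
        decay_steps=10_000 - 1_000,      # num_train_steps - num_warmup_steps
        end_learning_rate=0.0,           # init_lr * min_lr_ratio with min_lr_ratio=0.0
    )
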
@@ -182,7 +189,7 @@ class AdamWeightDecay(Adam):
     to adding the square of the weights to the loss with plain (non-momentum) SGD.

     Args:
-        learning_rate (`Union[float, keras.optimizers.schedules.LearningRateSchedule]`, *optional*, defaults to 0.001):
+        learning_rate (`Union[float, LearningRateSchedule]`, *optional*, defaults to 0.001):
             The learning rate to use or a schedule.
         beta_1 (`float`, *optional*, defaults to 0.9):
             The beta1 parameter in Adam, which is the exponential decay rate for the 1st momentum estimates.
@@ -212,7 +219,7 @@ class AdamWeightDecay(Adam):
     def __init__(
         self,
-        learning_rate: Union[float, keras.optimizers.schedules.LearningRateSchedule] = 0.001,
+        learning_rate: Union[float, schedules.LearningRateSchedule] = 0.001,
         beta_1: float = 0.9,
         beta_2: float = 0.999,
         epsilon: float = 1e-7,
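
Because `learning_rate` is typed `Union[float, schedules.LearningRateSchedule]`, the optimizer accepts either a constant value or a schedule object such as the warmup-wrapped decay built by `create_optimizer`. A hedged usage sketch through that public helper, with illustrative hyperparameters:

    from transformers import create_optimizer

    # Returns an AdamWeightDecay optimizer plus the underlying warmup + decay
    # schedule; the step counts and weight decay rate below are illustrative.
    optimizer, lr_schedule = create_optimizer(
        init_lr=5e-5,
        num_train_steps=10_000,
        num_warmup_steps=1_000,
        weight_decay_rate=0.01,
    )
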
@@ -301,7 +308,7 @@ class AdamWeightDecay(Adam):
 # Extracted from https://github.com/OpenNMT/OpenNMT-tf/blob/master/opennmt/optimizers/utils.py
-class GradientAccumulator(object):
+class GradientAccumulator:
     """
     Gradient accumulation utility. When used with a distribution strategy, the accumulator should be called in a
     replica context. Gradients will be accumulated locally on each replica and without synchronization. Users should
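
As the docstring says, the accumulator just sums gradients locally until the caller decides to apply them. A rough usage sketch, assuming the call / `.gradients` / `.reset()` interface of the original OpenNMT-tf utility; the variables, gradients, and step counts below are stand-ins:

    import tensorflow as tf
    from transformers.optimization_tf import GradientAccumulator

    variables = [tf.Variable(tf.zeros((2, 2)))]
    accumulator = GradientAccumulator()
    optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
    accumulate_steps = 4  # apply once every 4 micro-batches (illustrative)

    for step in range(8):
        grads = [tf.ones((2, 2))]  # stand-in for per-micro-batch gradients
        accumulator(grads)         # accumulate locally on this replica
        if (step + 1) % accumulate_steps == 0:
            optimizer.apply_gradients(zip(accumulator.gradients, variables))
            accumulator.reset()
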