Unverified commit 4f1788b3 authored by Matt, committed by GitHub

Fix AdamWeightDecay for TF 2.11 (#20735)

* Fix AdamWeightDecay for TF

* Fix AdamWeightDecay for TF

* make fixup
parent a12c5cbc
@@ -21,6 +21,12 @@ from typing import Callable, List, Optional, Union

 import tensorflow as tf

+if hasattr(tf.keras, "optimizers") and hasattr(tf.keras.optimizers, "legacy"):
+    Adam = tf.keras.optimizers.legacy.Adam
+else:
+    Adam = tf.keras.optimizers.Adam
+
 class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
     """
     Applies a warmup schedule on a given learning rate decay schedule.
@@ -163,7 +169,7 @@ def create_optimizer(
     return optimizer, lr_schedule

-class AdamWeightDecay(tf.keras.optimizers.Adam):
+class AdamWeightDecay(Adam):
""" """
Adam enables L2 weight decay and clip_by_global_norm on gradients. Just adding the square of the weights to the Adam enables L2 weight decay and clip_by_global_norm on gradients. Just adding the square of the weights to the
loss function is *not* the correct way of using L2 regularization/weight decay with Adam, since that will interact loss function is *not* the correct way of using L2 regularization/weight decay with Adam, since that will interact
...
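For context, TF 2.11 repointed `tf.keras.optimizers.Adam` at the new optimizer implementation, whose internals differ from the legacy class that `AdamWeightDecay` subclasses, while the original classes moved under `tf.keras.optimizers.legacy`. The sketch below is a minimal standalone illustration of the same compatibility pattern, not the library code; the `WeightDecayAdam` subclass and its stored decay rate are simplified assumptions made for the example.

```python
import tensorflow as tf

# TF >= 2.11 keeps the original optimizer classes under tf.keras.optimizers.legacy,
# while tf.keras.optimizers.Adam points at the new implementation. Older TF
# releases have no "legacy" submodule, so the plain class is used there.
if hasattr(tf.keras, "optimizers") and hasattr(tf.keras.optimizers, "legacy"):
    Adam = tf.keras.optimizers.legacy.Adam
else:
    Adam = tf.keras.optimizers.Adam


class WeightDecayAdam(Adam):
    """Hypothetical stand-in for AdamWeightDecay: reuses the selected base class."""

    def __init__(self, weight_decay_rate=0.01, **kwargs):
        super().__init__(**kwargs)
        # Simplified: only record the decay rate; the real class applies it per step.
        self.weight_decay_rate = weight_decay_rate


if __name__ == "__main__":
    opt = WeightDecayAdam(learning_rate=1e-3, weight_decay_rate=0.01)
    # Which base class was resolved depends on the installed TF version.
    print(type(opt).__mro__[1])
```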