Commit 510736ba authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 324693165
parent 76b4d0e7
@@ -106,6 +106,7 @@ class AdamWeightDecayConfig(base_config.Config):
  weight_decay_rate: float = 0.0
  include_in_weight_decay: Optional[List[str]] = None
  exclude_from_weight_decay: Optional[List[str]] = None
+  gradient_clip_norm: float = 1.0
@dataclasses.dataclass
......
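For context, a minimal sketch of how the new config field would be set alongside the existing ones. The import path below is an assumption (the file names are not shown in this view); adjust it to wherever AdamWeightDecayConfig lives in your checkout.

# Minimal sketch; import path is an assumption, not part of this commit.
from official.modeling.optimization.configs import optimizer_config

opt_cfg = optimizer_config.AdamWeightDecayConfig(
    weight_decay_rate=0.01,
    exclude_from_weight_decay=['LayerNorm', 'bias'],
    gradient_clip_norm=0.0,  # newly added field; presumably 0.0 turns off the implicit clipping
)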
@@ -130,13 +130,16 @@ class AdamWeightDecay(tf.keras.optimizers.Adam):
               weight_decay_rate=0.0,
               include_in_weight_decay=None,
               exclude_from_weight_decay=None,
+               gradient_clip_norm=1.0,
               name='AdamWeightDecay',
               **kwargs):
    super(AdamWeightDecay, self).__init__(learning_rate, beta_1, beta_2,
                                          epsilon, amsgrad, name, **kwargs)
    self.weight_decay_rate = weight_decay_rate
+    self.gradient_clip_norm = gradient_clip_norm
    self._include_in_weight_decay = include_in_weight_decay
    self._exclude_from_weight_decay = exclude_from_weight_decay
+    logging.info('gradient_clip_norm=%f', gradient_clip_norm)

  @classmethod
  def from_config(cls, config):
@@ -165,7 +168,7 @@ class AdamWeightDecay(tf.keras.optimizers.Adam):
                      name=None,
                      experimental_aggregate_gradients=True):
    grads, tvars = list(zip(*grads_and_vars))
-    if experimental_aggregate_gradients:
+    if experimental_aggregate_gradients and self.gradient_clip_norm > 0.0:
      # When experimental_aggregate_gradients = False, apply_gradients() no
      # longer implicitly allreduces gradients; users manually allreduce
      # gradients and pass the allreduced grads_and_vars. For now, the
......
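As a usage note, here is a minimal sketch of the effect of the new knob, assuming AdamWeightDecay is importable from the Model Garden's optimization module (the file path is not shown in this view). With the default gradient_clip_norm=1.0 the implicit clipping inside apply_gradients() is kept; passing 0.0 skips it, so callers that allreduce and clip gradients themselves are not clipped twice.

# Minimal sketch; the import path is an assumption, not part of this commit.
import tensorflow as tf
from official.nlp import optimization

optimizer = optimization.AdamWeightDecay(
    learning_rate=1e-4,
    weight_decay_rate=0.01,
    gradient_clip_norm=0.0,  # 0.0 disables the implicit clipping added above
)

# Tiny end-to-end check that apply_gradients() still runs with the new argument.
var = tf.Variable(2.0)
with tf.GradientTape() as tape:
  loss = var * var
grads = tape.gradient(loss, [var])
optimizer.apply_gradients(zip(grads, [var]))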