Commit dab0c03a authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 264474346
parent 8089a561
@@ -105,12 +105,14 @@ class AdamWeightDecay(tf.keras.optimizers.Adam):
                epsilon=1e-7,
                amsgrad=False,
                weight_decay_rate=0.0,
+               include_in_weight_decay=None,
                exclude_from_weight_decay=None,
                name='AdamWeightDecay',
                **kwargs):
     super(AdamWeightDecay, self).__init__(
         learning_rate, beta_1, beta_2, epsilon, amsgrad, name, **kwargs)
     self.weight_decay_rate = weight_decay_rate
+    self._include_in_weight_decay = include_in_weight_decay
     self._exclude_from_weight_decay = exclude_from_weight_decay
 
   @classmethod
@@ -178,6 +180,12 @@ class AdamWeightDecay(tf.keras.optimizers.Adam):
     """Whether to use L2 weight decay for `param_name`."""
     if self.weight_decay_rate == 0:
       return False
+
+    if self._include_in_weight_decay:
+      for r in self._include_in_weight_decay:
+        if re.search(r, param_name) is not None:
+          return True
+
     if self._exclude_from_weight_decay:
       for r in self._exclude_from_weight_decay:
         if re.search(r, param_name) is not None:
...
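The change adds an `include_in_weight_decay` allow-list of regex patterns that `_do_use_weight_decay` consults before the existing `exclude_from_weight_decay` deny-list, so a parameter whose name matches an include pattern receives weight decay even when an exclude pattern would otherwise suppress it. A minimal standalone sketch of that precedence (the `decay_applies` helper and the example patterns are illustrative, not part of the commit; the trailing `return False`/`return True` branches are assumed from the pre-existing method body, which the diff truncates):

```python
import re


def decay_applies(param_name,
                  weight_decay_rate,
                  include_patterns=None,
                  exclude_patterns=None):
  """Sketch of the patched `_do_use_weight_decay`: include wins over exclude."""
  if weight_decay_rate == 0:
    return False
  # The allow-list is checked first; a match forces weight decay on.
  if include_patterns:
    for r in include_patterns:
      if re.search(r, param_name) is not None:
        return True
  # The deny-list only applies to names that matched no include pattern.
  if exclude_patterns:
    for r in exclude_patterns:
      if re.search(r, param_name) is not None:
        return False
  # Default: decay everything not explicitly excluded.
  return True


# BERT-style configuration: decay kernels, skip LayerNorm params and biases.
excludes = ['LayerNorm', 'bias']
print(decay_applies('layer_0/attention/kernel', 0.01,
                    exclude_patterns=excludes))  # True
print(decay_applies('layer_0/LayerNorm/gamma', 0.01,
                    exclude_patterns=excludes))  # False
print(decay_applies('layer_0/LayerNorm/gamma', 0.01,
                    include_patterns=['gamma'],
                    exclude_patterns=excludes))  # True: include takes precedence
```

In BERT-style training the exclude list is commonly set to patterns such as `['LayerNorm', 'layer_norm', 'bias']`, so normalization parameters and biases are not decayed; the new include list gives callers a way to override those exclusions for specific variables.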