"magic_pdf/git@developer.sourcefind.cn:wangsen/mineru.git" did not exist on "e0c17eea7eca16cdf5dc79b933b94e94ea2c804b"
Commit 510736ba authored by Hongkun Yu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 324693165
parent 76b4d0e7
@@ -106,6 +106,7 @@ class AdamWeightDecayConfig(base_config.Config):
   weight_decay_rate: float = 0.0
   include_in_weight_decay: Optional[List[str]] = None
   exclude_from_weight_decay: Optional[List[str]] = None
+  gradient_clip_norm: float = 1.0


 @dataclasses.dataclass
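For readers skimming the config change: the sketch below is a self-contained, illustrative stand-in for the updated dataclass (the real AdamWeightDecayConfig derives from base_config.Config, which is not reproduced here). Only the fields visible in the hunk above are shown, and the construction at the end is hypothetical usage, not code from this commit.

import dataclasses
from typing import List, Optional


@dataclasses.dataclass
class AdamWeightDecayConfig:
  """Illustrative stand-in mirroring the fields shown in the hunk above."""
  weight_decay_rate: float = 0.0
  include_in_weight_decay: Optional[List[str]] = None
  exclude_from_weight_decay: Optional[List[str]] = None
  gradient_clip_norm: float = 1.0  # new field; defaults to clipping at global norm 1.0


# Hypothetical usage: a non-positive value lines up with the
# `self.gradient_clip_norm > 0.0` guard added to apply_gradients() below,
# i.e. it would skip the aggregated-gradient clipping path.
cfg = AdamWeightDecayConfig(weight_decay_rate=0.01, gradient_clip_norm=0.0)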
@@ -130,13 +130,16 @@ class AdamWeightDecay(tf.keras.optimizers.Adam):
                weight_decay_rate=0.0,
                include_in_weight_decay=None,
                exclude_from_weight_decay=None,
+               gradient_clip_norm=1.0,
                name='AdamWeightDecay',
                **kwargs):
     super(AdamWeightDecay, self).__init__(learning_rate, beta_1, beta_2,
                                           epsilon, amsgrad, name, **kwargs)
     self.weight_decay_rate = weight_decay_rate
+    self.gradient_clip_norm = gradient_clip_norm
     self._include_in_weight_decay = include_in_weight_decay
     self._exclude_from_weight_decay = exclude_from_weight_decay
+    logging.info('gradient_clip_norm=%f', gradient_clip_norm)

   @classmethod
   def from_config(cls, config):

@@ -165,7 +168,7 @@ class AdamWeightDecay(tf.keras.optimizers.Adam):
                       name=None,
                       experimental_aggregate_gradients=True):
     grads, tvars = list(zip(*grads_and_vars))
-    if experimental_aggregate_gradients:
+    if experimental_aggregate_gradients and self.gradient_clip_norm > 0.0:
       # when experimental_aggregate_gradients = False, apply_gradients() no
       # longer implicitly allreduce gradients, users manually allreduce gradient
       # and passed the allreduced grads_and_vars. For now, the
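To make the behavioral effect of the diff concrete, here is a hedged usage sketch: the constructor gains a `gradient_clip_norm` argument defaulting to 1.0, and the new guard in apply_gradients() implies that a non-positive value skips the clipping applied under experimental_aggregate_gradients=True. The module path and the example hyperparameters are assumptions, not taken from this commit.

from official.nlp import optimization  # assumed module path; not named in this diff

# Default after this commit: gradient_clip_norm=1.0, so aggregated gradients
# are clipped before the weight-decayed Adam update.
clipped_opt = optimization.AdamWeightDecay(
    learning_rate=2e-5,
    weight_decay_rate=0.01,
    exclude_from_weight_decay=['LayerNorm', 'layer_norm', 'bias'],
)

# Passing a non-positive value falls through the
# `experimental_aggregate_gradients and self.gradient_clip_norm > 0.0` guard,
# leaving the aggregated gradients unclipped.
unclipped_opt = optimization.AdamWeightDecay(
    learning_rate=2e-5,
    weight_decay_rate=0.01,
    gradient_clip_norm=0.0,
)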