Commit 5bf90fb5 authored by Frederick Liu's avatar Frederick Liu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 397848064
parent 7f69eb3b
......@@ -13,7 +13,7 @@
# limitations under the License.
"""Optimizer factory class."""
from typing import Callable, Optional, Union
from typing import Callable, Optional, Union, List, Tuple
import gin
import tensorflow as tf
......@@ -139,6 +139,9 @@ class OptimizerFactory:
def build_optimizer(
self,
lr: Union[tf.keras.optimizers.schedules.LearningRateSchedule, float],
gradient_transformers: Optional[List[Callable[
[List[Tuple[tf.Tensor, tf.Tensor]]], List[Tuple[tf.Tensor, tf.Tensor]]
]]] = None,
postprocessor: Optional[Callable[[tf.keras.optimizers.Optimizer],
tf.keras.optimizers.Optimizer]] = None):
"""Build optimizer.
......@@ -150,6 +153,11 @@ class OptimizerFactory:
Args:
lr: A floating point value, or a
tf.keras.optimizers.schedules.LearningRateSchedule instance.
gradient_transformers: Optional list of functions to use to transform
gradients before applying updates to Variables. The functions are
applied after gradient_aggregator. The functions should accept and
return a list of (gradient, variable) tuples. clipvalue, clipnorm,
global_clipnorm should not be set when gradient_transformers is passed.
postprocessor: An optional function for postprocessing the optimizer. It
takes an optimizer and returns an optimizer.
......@@ -158,13 +166,17 @@ class OptimizerFactory:
"""
optimizer_dict = self._optimizer_config.as_dict()
## Delete clipnorm and clipvalue if None
## Delete clipnorm, clipvalue, global_clipnorm if None
if optimizer_dict['clipnorm'] is None:
del optimizer_dict['clipnorm']
if optimizer_dict['clipvalue'] is None:
del optimizer_dict['clipvalue']
if optimizer_dict['global_clipnorm'] is None:
del optimizer_dict['global_clipnorm']
optimizer_dict['learning_rate'] = lr
if gradient_transformers is not None:
optimizer_dict['gradient_transformers'] = gradient_transformers
optimizer = OPTIMIZERS_CLS[self._optimizer_type](**optimizer_dict)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment