"git@developer.sourcefind.cn:OpenDAS/mmdetection3d.git" did not exist on "c0c500741554ec7e49aa58cbccf8f6a1c1a69a99"
Commit 5bf90fb5 authored by Frederick Liu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 397848064
parent 7f69eb3b
@@ -13,7 +13,7 @@
 # limitations under the License.
 """Optimizer factory class."""
-from typing import Callable, Optional, Union
+from typing import Callable, Optional, Union, List, Tuple
 import gin
 import tensorflow as tf
@@ -139,6 +139,9 @@ class OptimizerFactory:
   def build_optimizer(
       self,
       lr: Union[tf.keras.optimizers.schedules.LearningRateSchedule, float],
+      gradient_transformers: Optional[List[Callable[
+          [List[Tuple[tf.Tensor, tf.Tensor]]], List[Tuple[tf.Tensor, tf.Tensor]]
+      ]]] = None,
       postprocessor: Optional[Callable[[tf.keras.optimizers.Optimizer],
                                        tf.keras.optimizers.Optimizer]] = None):
     """Build optimizer.
@@ -150,6 +153,11 @@ class OptimizerFactory:
     Args:
       lr: A floating point value, or a
         tf.keras.optimizers.schedules.LearningRateSchedule instance.
+      gradient_transformers: Optional list of functions to use to transform
+        gradients before applying updates to Variables. The functions are
+        applied after gradient_aggregator. The functions should accept and
+        return a list of (gradient, variable) tuples. clipvalue, clipnorm,
+        global_clipnorm should not be set when gradient_transformers is passed.
       postprocessor: An optional function for postprocessing the optimizer. It
         takes an optimizer and returns an optimizer.
@@ -158,13 +166,17 @@ class OptimizerFactory:
     """
     optimizer_dict = self._optimizer_config.as_dict()
-    ## Delete clipnorm and clipvalue if None
+    ## Delete clipnorm, clipvalue, global_clipnorm if None
     if optimizer_dict['clipnorm'] is None:
       del optimizer_dict['clipnorm']
     if optimizer_dict['clipvalue'] is None:
       del optimizer_dict['clipvalue']
+    if optimizer_dict['global_clipnorm'] is None:
+      del optimizer_dict['global_clipnorm']
     optimizer_dict['learning_rate'] = lr
+    if gradient_transformers is not None:
+      optimizer_dict['gradient_transformers'] = gradient_transformers
     optimizer = OPTIMIZERS_CLS[self._optimizer_type](**optimizer_dict)
...
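For readers unfamiliar with the Keras hook this change exposes: a gradient transformer is a callable over (gradient, variable) pairs that the optimizer applies after gradient aggregation and before the update step. A minimal sketch of how the new argument could be exercised, assuming a TF 2.x tf.keras.optimizers.SGD of the generation this commit targets; the transformer name and clip value are illustrative, not part of this change.

import tensorflow as tf

def clip_by_global_norm_transformer(grads_and_vars):
  # Illustrative transformer: accepts and returns a list of
  # (gradient, variable) tuples, as the docstring above requires.
  grads, variables = zip(*grads_and_vars)
  clipped, _ = tf.clip_by_global_norm(list(grads), clip_norm=1.0)
  return list(zip(clipped, variables))

# The factory forwards the list via the optimizer's constructor kwargs,
# so the result is roughly equivalent to building the optimizer directly:
optimizer = tf.keras.optimizers.SGD(
    learning_rate=0.1,
    gradient_transformers=[clip_by_global_norm_transformer])

Note that because the factory only sets 'gradient_transformers' when the caller passes a non-None value, existing configs that never use the argument construct their optimizers exactly as before.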
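The pruning of unset clipping keys is what lets the two features coexist: per the docstring above, clipvalue, clipnorm, and global_clipnorm should not be set when gradient_transformers is passed, so dropping None entries keeps the **optimizer_dict expansion from ever handing those keys to the constructor. A standalone sketch of the same pattern, with a hypothetical config dict standing in for self._optimizer_config.as_dict():

import tensorflow as tf

# Hypothetical config dict; real values come from the OptimizerConfig.
optimizer_dict = {
    'learning_rate': 0.1,
    'momentum': 0.9,
    'clipnorm': None,
    'clipvalue': None,
    'global_clipnorm': None,
}

# Same pruning pattern as the factory: drop clipping options that were
# never configured so they are not passed to the optimizer at all.
for key in ('clipnorm', 'clipvalue', 'global_clipnorm'):
  if optimizer_dict[key] is None:
    del optimizer_dict[key]

optimizer = tf.keras.optimizers.SGD(**optimizer_dict)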