Commit 3b847220 authored by Abdullah Rashwan, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 332095189
parent e6ffa057
@@ -26,15 +26,15 @@ class OptimizerConfigTest(tf.test.TestCase):
 
   def test_no_optimizer(self):
     optimizer = optimization_config.OptimizationConfig({}).optimizer.get()
-    self.assertEqual(optimizer, None)
+    self.assertIsNone(optimizer)
 
   def test_no_lr_schedule(self):
     lr = optimization_config.OptimizationConfig({}).learning_rate.get()
-    self.assertEqual(lr, None)
+    self.assertIsNone(lr)
 
   def test_no_warmup_schedule(self):
     warmup = optimization_config.OptimizationConfig({}).warmup.get()
-    self.assertEqual(warmup, None)
+    self.assertIsNone(warmup)
 
   def test_config(self):
     opt_config = optimization_config.OptimizationConfig({
...
@@ -21,7 +21,21 @@ from official.modeling.hyperparams import base_config
 
 
 @dataclasses.dataclass
-class SGDConfig(base_config.Config):
+class BaseOptimizerConfig(base_config.Config):
+  """Base optimizer config.
+
+  Attributes:
+    clipnorm: float >= 0 or None. If not None, Gradients will be clipped when
+      their L2 norm exceeds this value.
+    clipvalue: float >= 0 or None. If not None, Gradients will be clipped when
+      their absolute value exceeds this value.
+  """
+  clipnorm: Optional[float] = None
+  clipvalue: Optional[float] = None
+
+
+@dataclasses.dataclass
+class SGDConfig(BaseOptimizerConfig):
   """Configuration for SGD optimizer.
 
   The attributes for this class matches the arguments of tf.keras.optimizer.SGD.
@@ -39,7 +53,7 @@ class SGDConfig(base_config.Config):
 
 
 @dataclasses.dataclass
-class RMSPropConfig(base_config.Config):
+class RMSPropConfig(BaseOptimizerConfig):
   """Configuration for RMSProp optimizer.
 
   The attributes for this class matches the arguments of
@@ -60,7 +74,7 @@ class RMSPropConfig(base_config.Config):
 
 
 @dataclasses.dataclass
-class AdamConfig(base_config.Config):
+class AdamConfig(BaseOptimizerConfig):
   """Configuration for Adam optimizer.
 
   The attributes for this class matches the arguments of
@@ -82,7 +96,7 @@ class AdamConfig(base_config.Config):
 
 
 @dataclasses.dataclass
-class AdamWeightDecayConfig(base_config.Config):
+class AdamWeightDecayConfig(BaseOptimizerConfig):
   """Configuration for Adam optimizer with weight decay.
 
   Attributes:
@@ -110,7 +124,7 @@ class AdamWeightDecayConfig(base_config.Config):
 
 
 @dataclasses.dataclass
-class LAMBConfig(base_config.Config):
+class LAMBConfig(BaseOptimizerConfig):
   """Configuration for LAMB optimizer.
 
   The attributes for this class matches the arguments of
@@ -139,7 +153,7 @@ class LAMBConfig(base_config.Config):
 
 
 @dataclasses.dataclass
-class EMAConfig(base_config.Config):
+class EMAConfig(BaseOptimizerConfig):
  """Exponential moving average optimizer config.
 
   Attributes:
...
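With clipnorm and clipvalue hoisted into BaseOptimizerConfig, every optimizer config above now accepts the two clipping fields. A minimal sketch of how a config built from a plain dict might request L2-norm clipping, reusing the 'sgd' and 'constant' keys that appear in the test further down (the import path is assumed from the surrounding package layout, and the numeric values are illustrative):

from official.modeling.optimization.configs import optimization_config

# Illustrative only: request clipping by L2 norm while leaving element-wise
# clipping (clipvalue) at its default of None.
params = {
    'optimizer': {
        'type': 'sgd',
        'sgd': {'momentum': 0.9, 'clipnorm': 1.0}
    },
    'learning_rate': {
        'type': 'constant',
        'constant': {'learning_rate': 0.01}
    }
}
opt_config = optimization_config.OptimizationConfig(params)

Unset fields keep their None defaults, so existing configs that never mention clipnorm or clipvalue are unaffected by this change.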
@@ -144,6 +144,12 @@ class OptimizerFactory(object):
     """
     optimizer_dict = self._optimizer_config.as_dict()
+    ## Delete clipnorm and clipvalue if None
+    if optimizer_dict['clipnorm'] is None:
+      del optimizer_dict['clipnorm']
+    if optimizer_dict['clipvalue'] is None:
+      del optimizer_dict['clipvalue']
+
     optimizer_dict['learning_rate'] = lr
 
     optimizer = OPTIMIZERS_CLS[self._optimizer_type](**optimizer_dict)
...
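The two del statements keep the keyword arguments forwarded to the Keras optimizer minimal: when clipping is not requested, no clipnorm or clipvalue key is passed at all, presumably because the tf.keras optimizer constructors expect these arguments to be omitted rather than supplied as an explicit None. The construction is roughly equivalent to the hand-written call below (a sketch, not the factory's actual code path; the dict values are made up):

import tensorflow as tf

# Dict as OptimizationConfig.as_dict() might produce for an SGD config,
# with clipvalue requested and clipnorm left unset (values are illustrative).
optimizer_dict = {'momentum': 0.9, 'nesterov': False,
                  'clipnorm': None, 'clipvalue': 1.0}

# Mirror the factory: drop the clipping keys that are still None ...
for key in ('clipnorm', 'clipvalue'):
  if optimizer_dict[key] is None:
    del optimizer_dict[key]

# ... then forward everything as keyword arguments, just like
# OPTIMIZERS_CLS[self._optimizer_type](**optimizer_dict) does.
optimizer_dict['learning_rate'] = 0.01
optimizer = tf.keras.optimizers.SGD(**optimizer_dict)  # element-wise clipping at 1.0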
@@ -14,9 +14,8 @@
 # limitations under the License.
 # ==============================================================================
 """Tests for optimizer_factory.py."""
 
 from absl.testing import parameterized
-import numpy as np
 import tensorflow as tf
 
 from official.modeling.optimization import optimizer_factory
@@ -50,6 +49,49 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
     self.assertIsInstance(optimizer, optimizer_cls)
     self.assertEqual(expected_optimizer_config, optimizer.get_config())
 
+  @parameterized.parameters(
+      (None, None),
+      (1.0, None),
+      (None, 1.0))
+  def test_gradient_clipping(self, clipnorm, clipvalue):
+    params = {
+        'optimizer': {
+            'type': 'sgd',
+            'sgd': {
+                'clipnorm': clipnorm,
+                'clipvalue': clipvalue
+            }
+        },
+        'learning_rate': {
+            'type': 'constant',
+            'constant': {
+                'learning_rate': 1.0
+            }
+        }
+    }
+    opt_config = optimization_config.OptimizationConfig(params)
+    opt_factory = optimizer_factory.OptimizerFactory(opt_config)
+    lr = opt_factory.build_learning_rate()
+    optimizer = opt_factory.build_optimizer(lr)
+
+    var0 = tf.Variable([1.0, 2.0])
+    var1 = tf.Variable([3.0, 4.0])
+    grads0 = tf.constant([0.1, 0.1])
+    grads1 = tf.constant([2.0, 3.0])
+    grads_and_vars = list(zip([grads0, grads1], [var0, var1]))
+    optimizer.apply_gradients(grads_and_vars)
+
+    # Plain lists work with assertAllClose; numpy is no longer imported here.
+    self.assertAllClose([0.9, 1.9], var0.numpy())
+    if clipvalue is not None:
+      self.assertAllClose([2.0, 3.0], var1.numpy())
+    elif clipnorm is not None:
+      self.assertAllClose([2.4452999, 3.1679497], var1.numpy())
+    else:
+      self.assertAllClose([1.0, 1.0], var1.numpy())
+
   def test_missing_types(self):
     params = {'optimizer': {'type': 'sgd', 'sgd': {'momentum': 0.9}}}
     with self.assertRaises(ValueError):
...
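The expected values in test_gradient_clipping follow from a single SGD step at learning rate 1.0: var0's gradient [0.1, 0.1] is small enough to be unaffected by either clipping mode, while var1's gradient [2.0, 3.0] is the one that gets clipped. A quick check of the clipnorm=1.0 branch (plain Python, not part of the commit):

import math

grads1 = [2.0, 3.0]
var1 = [3.0, 4.0]

# clipnorm rescales the gradient so its L2 norm is at most 1.0.
norm = math.sqrt(sum(g * g for g in grads1))        # sqrt(13) ~ 3.6056
clipped = [g / norm for g in grads1]                # ~ [0.5547, 0.8321]

# One SGD step with learning_rate=1.0.
updated = [v - g for v, g in zip(var1, clipped)]    # ~ [2.4453, 3.1679]
print(updated)

The clipvalue=1.0 branch instead caps each gradient component at 1.0, giving [3.0, 4.0] - [1.0, 1.0] = [2.0, 3.0], and with no clipping the step subtracts the raw gradient, giving [1.0, 1.0].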