ModelZoo / ResNet50_tensorflow

Commit 0a6f6426, authored Jun 17, 2020 by Abdullah Rashwan, committed by A. Unique TensorFlower on Jun 17, 2020

Internal change

PiperOrigin-RevId: 316962972

Parent: 57e7ca73
Showing 4 changed files with 41 additions and 63 deletions.
official/modeling/optimization/configs/optimization_config.py   +2  -0
official/modeling/optimization/configs/optimizer_config.py      +23 -0
official/modeling/optimization/optimizer_factory.py             +2  -2
official/modeling/optimization/optimizer_factory_test.py        +14 -61
official/modeling/optimization/configs/optimization_config.py
@@ -39,12 +39,14 @@ class OptimizerConfig(oneof.OneOfConfig):
     adam: adam optimizer config.
     adamw: adam with weight decay.
     lamb: lamb optimizer.
+    rmsprop: rmsprop optimizer.
   """
   type: Optional[str] = None
   sgd: opt_cfg.SGDConfig = opt_cfg.SGDConfig()
   adam: opt_cfg.AdamConfig = opt_cfg.AdamConfig()
   adamw: opt_cfg.AdamWeightDecayConfig = opt_cfg.AdamWeightDecayConfig()
   lamb: opt_cfg.LAMBConfig = opt_cfg.LAMBConfig()
+  rmsprop: opt_cfg.RMSPropConfig = opt_cfg.RMSPropConfig()
 
 
 @dataclasses.dataclass
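With the new oneof field in place, a training config can select RMSProp by name. A minimal sketch (not part of the commit; the override values are illustrative), mirroring the dict-based pattern used in the tests further down:

from official.modeling.optimization.configs import optimization_config

# Select the rmsprop optimizer via the oneof 'type' key and override two of
# its defaults; any field left unspecified keeps its RMSPropConfig default.
params = {
    'optimizer': {
        'type': 'rmsprop',
        'rmsprop': {
            'learning_rate': 0.01,
            'momentum': 0.9
        }
    }
}
opt_config = optimization_config.OptimizationConfig(params)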
official/modeling/optimization/configs/optimizer_config.py
@@ -40,6 +40,29 @@ class SGDConfig(base_config.Config):
   momentum: float = 0.0
 
 
+@dataclasses.dataclass
+class RMSPropConfig(base_config.Config):
+  """Configuration for RMSProp optimizer.
+
+  The attributes for this class matches the arguments of
+  tf.keras.optimizers.RMSprop.
+
+  Attributes:
+    name: name of the optimizer.
+    learning_rate: learning_rate for RMSprop optimizer.
+    rho: discounting factor for RMSprop optimizer.
+    momentum: momentum for RMSprop optimizer.
+    epsilon: epsilon value for RMSprop optimizer, help with numerical
+      stability.
+    centered: Whether to normalize gradients or not.
+  """
+  name: str = "RMSprop"
+  learning_rate: float = 0.001
+  rho: float = 0.9
+  momentum: float = 0.0
+  epsilon: float = 1e-7
+  centered: bool = False
+
+
 @dataclasses.dataclass
 class AdamConfig(base_config.Config):
   """Configuration for Adam optimizer.
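Since RMSPropConfig subclasses base_config.Config like the other optimizer configs, its fields can be overridden at construction time. A minimal sketch (not part of the commit):

from official.modeling.optimization.configs import optimizer_config

# Override two fields; the rest keep the dataclass defaults shown above.
rmsprop_cfg = optimizer_config.RMSPropConfig(learning_rate=0.01, centered=True)
assert rmsprop_cfg.rho == 0.9       # default
assert rmsprop_cfg.epsilon == 1e-7  # default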
official/modeling/optimization/optimizer_factory.py
@@ -14,7 +14,6 @@
 # limitations under the License.
 # ==============================================================================
 """Optimizer factory class."""
-
 from typing import Union
 
 import tensorflow as tf

@@ -29,7 +28,8 @@ OPTIMIZERS_CLS = {
     'sgd': tf.keras.optimizers.SGD,
     'adam': tf.keras.optimizers.Adam,
     'adamw': nlp_optimization.AdamWeightDecay,
-    'lamb': tfa_optimizers.LAMB
+    'lamb': tfa_optimizers.LAMB,
+    'rmsprop': tf.keras.optimizers.RMSprop
 }
 
 LR_CLS = {
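The new OPTIMIZERS_CLS entry is what lets the factory build a tf.keras.optimizers.RMSprop instance from a config. An end-to-end sketch (not part of the commit), following the same build_learning_rate()/build_optimizer() flow the tests use:

import tensorflow as tf

from official.modeling.optimization import optimizer_factory
from official.modeling.optimization.configs import optimization_config

params = {'optimizer': {'type': 'rmsprop'}}
opt_config = optimization_config.OptimizationConfig(params)
opt_factory = optimizer_factory.OptimizerFactory(opt_config)

# The factory resolves 'rmsprop' through OPTIMIZERS_CLS and instantiates the
# class with the values carried by RMSPropConfig.
lr = opt_factory.build_learning_rate()
optimizer = opt_factory.build_optimizer(lr)
assert isinstance(optimizer, tf.keras.optimizers.RMSprop)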
official/modeling/optimization/optimizer_factory_test.py
@@ -15,84 +15,37 @@
 # ==============================================================================
 """Tests for optimizer_factory.py."""
+from absl.testing import parameterized
 import tensorflow as tf
-import tensorflow_addons.optimizers as tfa_optimizers
 
 from official.modeling.optimization import optimizer_factory
 from official.modeling.optimization.configs import optimization_config
-from official.nlp import optimization as nlp_optimization
 
 
-class OptimizerFactoryTest(tf.test.TestCase):
+class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
 
-  def test_sgd_optimizer(self):
-    params = {
-        'optimizer': {
-            'type': 'sgd',
-            'sgd': {
-                'learning_rate': 0.1,
-                'momentum': 0.9
-            }
-        }
-    }
-    expected_optimizer_config = {
-        'name': 'SGD',
-        'learning_rate': 0.1,
-        'decay': 0.0,
-        'momentum': 0.9,
-        'nesterov': False
-    }
-    opt_config = optimization_config.OptimizationConfig(params)
-    opt_factory = optimizer_factory.OptimizerFactory(opt_config)
-    lr = opt_factory.build_learning_rate()
-    optimizer = opt_factory.build_optimizer(lr)
-
-    self.assertIsInstance(optimizer, tf.keras.optimizers.SGD)
-    self.assertEqual(expected_optimizer_config, optimizer.get_config())
-
-  def test_adam_optimizer(self):
-    # Define adam optimizer with default values.
-    params = {
-        'optimizer': {
-            'type': 'adam'
-        }
-    }
-    expected_optimizer_config = tf.keras.optimizers.Adam().get_config()
-    opt_config = optimization_config.OptimizationConfig(params)
-    opt_factory = optimizer_factory.OptimizerFactory(opt_config)
-    lr = opt_factory.build_learning_rate()
-    optimizer = opt_factory.build_optimizer(lr)
-
-    self.assertIsInstance(optimizer, tf.keras.optimizers.Adam)
-    self.assertEqual(expected_optimizer_config, optimizer.get_config())
-
-  def test_adam_weight_decay_optimizer(self):
+  @parameterized.parameters(('sgd'), ('rmsprop'), ('adam'), ('adamw'), ('lamb'))
+  def test_optimizers(self, optimizer_type):
     params = {
         'optimizer': {
-            'type': 'adamw'
+            'type': optimizer_type
         }
     }
-    expected_optimizer_config = nlp_optimization.AdamWeightDecay().get_config()
+    optimizer_cls = optimizer_factory.OPTIMIZERS_CLS[optimizer_type]
+    expected_optimizer_config = optimizer_cls().get_config()
+
     opt_config = optimization_config.OptimizationConfig(params)
     opt_factory = optimizer_factory.OptimizerFactory(opt_config)
     lr = opt_factory.build_learning_rate()
     optimizer = opt_factory.build_optimizer(lr)
 
-    self.assertIsInstance(optimizer, nlp_optimization.AdamWeightDecay)
-    self.assertEqual(expected_optimizer_config, optimizer.get_config())
-
-  def test_lamb_optimizer(self):
-    params = {
-        'optimizer': {
-            'type': 'lamb'
-        }
-    }
-    expected_optimizer_config = tfa_optimizers.LAMB().get_config()
-    opt_config = optimization_config.OptimizationConfig(params)
-    opt_factory = optimizer_factory.OptimizerFactory(opt_config)
-    lr = opt_factory.build_learning_rate()
-    optimizer = opt_factory.build_optimizer(lr)
-
-    self.assertIsInstance(optimizer, tfa_optimizers.LAMB)
+    self.assertIsInstance(optimizer, optimizer_cls)
     self.assertEqual(expected_optimizer_config, optimizer.get_config())
 
   def test_stepwise_lr_schedule(self):
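The test refactor collapses four near-identical per-optimizer tests into one parameterized test, so supporting a new optimizer now only requires a new OPTIMIZERS_CLS entry and one more name in the decorator. A standalone sketch (hypothetical class and test names, not part of the commit) of how absl's parameterized expands each argument into its own test case:

from absl.testing import parameterized
import tensorflow as tf

class DemoTest(tf.test.TestCase, parameterized.TestCase):

  # Each argument becomes a separately generated test case, e.g.
  # test_type_name0 ('sgd') and test_type_name1 ('rmsprop').
  @parameterized.parameters(('sgd',), ('rmsprop',))
  def test_type_name(self, optimizer_type):
    self.assertIsInstance(optimizer_type, str)

if __name__ == '__main__':
  tf.test.main()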