ModelZoo / ResNet50_tensorflow / Commits / eb6fa0b2

Commit eb6fa0b2, authored Apr 07, 2020 by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 305380592
Parent: 98074f7a

Showing 12 changed files with 381 additions and 36 deletions (+381 / -36)
official/vision/image_classification/callbacks.py                                                     +114  -6
official/vision/image_classification/classifier_trainer.py                                            +14   -12
official/vision/image_classification/configs/base_configs.py                                          +5    -0
official/vision/image_classification/configs/configs.py                                               +4    -2
official/vision/image_classification/configs/examples/efficientnet/imagenet/efficientnet-b0-gpu.yaml  +2    -0
official/vision/image_classification/configs/examples/efficientnet/imagenet/efficientnet-b0-tpu.yaml  +1    -1
official/vision/image_classification/configs/examples/efficientnet/imagenet/efficientnet-b1-gpu.yaml  +2    -0
official/vision/image_classification/configs/examples/efficientnet/imagenet/efficientnet-b1-tpu.yaml  +2    -0
official/vision/image_classification/efficientnet/common_modules.py                                   +2    -3
official/vision/image_classification/efficientnet/efficientnet_config.py                              +2    -1
official/vision/image_classification/optimizer_factory.py                                             +230  -8
official/vision/image_classification/optimizer_factory_test.py                                        +3    -3
official/vision/image_classification/callbacks.py (view file @ eb6fa0b2)

```diff
@@ -20,12 +20,12 @@ from __future__ import division
 from __future__ import print_function

 import os
+from typing import Any, List, MutableMapping, Text

 from absl import logging
 import tensorflow as tf
-from typing import Any, List, MutableMapping

 from official.utils.misc import keras_utils
+from official.vision.image_classification import optimizer_factory


 def get_callbacks(model_checkpoint: bool = True,
@@ -33,6 +33,7 @@ def get_callbacks(model_checkpoint: bool = True,
                   time_history: bool = True,
                   track_lr: bool = True,
                   write_model_weights: bool = True,
+                  apply_moving_average: bool = False,
                   initial_step: int = 0,
                   batch_size: int = 0,
                   log_steps: int = 0,
@@ -42,9 +43,8 @@ def get_callbacks(model_checkpoint: bool = True,
   callbacks = []
   if model_checkpoint:
     ckpt_full_path = os.path.join(model_dir, 'model.ckpt-{epoch:04d}')
     callbacks.append(
         tf.keras.callbacks.ModelCheckpoint(
             ckpt_full_path, save_weights_only=True, verbose=1))
   if include_tensorboard:
     callbacks.append(
         CustomTensorBoard(
@@ -58,6 +58,17 @@ def get_callbacks(model_checkpoint: bool = True,
         batch_size,
         log_steps,
         logdir=model_dir if include_tensorboard else None))
+  if apply_moving_average:
+    # Save moving average model to a different file so that
+    # we can resume training from a checkpoint
+    ckpt_full_path = os.path.join(
+        model_dir, 'average', 'model.ckpt-{epoch:04d}')
+    callbacks.append(AverageModelCheckpoint(
+        update_weights=False,
+        filepath=ckpt_full_path,
+        save_weights_only=True,
+        verbose=1))
+    callbacks.append(MovingAverageCallback())
   return callbacks
@@ -136,7 +147,7 @@ class CustomTensorBoard(tf.keras.callbacks.TensorBoard):
   def _calculate_lr(self) -> int:
     """Calculates the learning rate given the current step."""
     return get_scalar_from_tensor(
-        self._get_base_optimizer()._decayed_lr(var_dtype=tf.float32))
+        self._get_base_optimizer()._decayed_lr(var_dtype=tf.float32))  # pylint:disable=protected-access

   def _get_base_optimizer(self) -> tf.keras.optimizers.Optimizer:
     """Get the base optimizer used by the current model."""
@@ -148,3 +159,100 @@ class CustomTensorBoard(tf.keras.callbacks.TensorBoard):
       optimizer = optimizer._optimizer  # pylint:disable=protected-access
     return optimizer
+
+
+class MovingAverageCallback(tf.keras.callbacks.Callback):
+  """A Callback to be used with a `MovingAverage` optimizer.
+
+  Applies moving average weights to the model during validation time to test
+  and predict on the averaged weights rather than the current model weights.
+  Once training is complete, the model weights will be overwritten with the
+  averaged weights (by default).
+
+  Attributes:
+    overwrite_weights_on_train_end: Whether to overwrite the current model
+      weights with the averaged weights from the moving average optimizer.
+    **kwargs: Any additional callback arguments.
+  """
+
+  def __init__(self,
+               overwrite_weights_on_train_end: bool = False,
+               **kwargs):
+    super(MovingAverageCallback, self).__init__(**kwargs)
+    self.overwrite_weights_on_train_end = overwrite_weights_on_train_end
+
+  def set_model(self, model: tf.keras.Model):
+    super(MovingAverageCallback, self).set_model(model)
+    assert isinstance(self.model.optimizer, optimizer_factory.MovingAverage)
+    self.model.optimizer.shadow_copy(self.model)
+
+  def on_test_begin(self, logs: MutableMapping[Text, Any] = None):
+    self.model.optimizer.swap_weights()
+
+  def on_test_end(self, logs: MutableMapping[Text, Any] = None):
+    self.model.optimizer.swap_weights()
+
+  def on_train_end(self, logs: MutableMapping[Text, Any] = None):
+    if self.overwrite_weights_on_train_end:
+      self.model.optimizer.assign_average_vars(self.model.variables)
+
+
+class AverageModelCheckpoint(tf.keras.callbacks.ModelCheckpoint):
+  """Saves and, optionally, assigns the averaged weights.
+
+  Taken from tfa.callbacks.AverageModelCheckpoint.
+
+  Attributes:
+    update_weights: If True, assign the moving average weights
+      to the model, and save them. If False, keep the old
+      non-averaged weights, but the saved model uses the
+      average weights.
+
+  See `tf.keras.callbacks.ModelCheckpoint` for the other args.
+  """
+
+  def __init__(self,
+               update_weights: bool,
+               filepath: str,
+               monitor: str = 'val_loss',
+               verbose: int = 0,
+               save_best_only: bool = False,
+               save_weights_only: bool = False,
+               mode: str = 'auto',
+               save_freq: str = 'epoch',
+               **kwargs):
+    self.update_weights = update_weights
+    super().__init__(filepath, monitor, verbose, save_best_only,
+                     save_weights_only, mode, save_freq, **kwargs)
+
+  def set_model(self, model):
+    if not isinstance(model.optimizer, optimizer_factory.MovingAverage):
+      raise TypeError('AverageModelCheckpoint is only used when training'
+                      'with MovingAverage')
+    return super().set_model(model)
+
+  def _save_model(self, epoch, logs):
+    assert isinstance(self.model.optimizer, optimizer_factory.MovingAverage)
+
+    if self.update_weights:
+      self.model.optimizer.assign_average_vars(self.model.variables)
+      return super()._save_model(epoch, logs)
+    else:
+      # Note: `model.get_weights()` gives us the weights (non-ref)
+      # whereas `model.variables` returns references to the variables.
+      non_avg_weights = self.model.get_weights()
+      self.model.optimizer.assign_average_vars(self.model.variables)
+      # result is currently None, since `super._save_model` doesn't
+      # return anything, but this may change in the future.
+      result = super()._save_model(epoch, logs)
+      self.model.set_weights(non_avg_weights)
+      return result
```
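A minimal sketch of how these new callbacks are wired together, assuming the optimizer has already been wrapped in `optimizer_factory.MovingAverage` (which `build_optimizer` normally does; the model, dataset, and path below are illustrative stand-ins, not part of the diff):

```python
import tensorflow as tf
from official.vision.image_classification import callbacks as custom_callbacks
from official.vision.image_classification import optimizer_factory

model = tf.keras.Sequential([tf.keras.layers.Dense(2)])
model.compile(
    optimizer=optimizer_factory.MovingAverage(tf.keras.optimizers.SGD(0.1)),
    loss='mse')

cb = custom_callbacks.get_callbacks(
    model_checkpoint=True,
    include_tensorboard=False,
    time_history=False,
    apply_moving_average=True,   # the new flag from this commit
    model_dir='/tmp/model_dir')  # hypothetical path

data = tf.data.Dataset.from_tensor_slices(
    (tf.random.normal([16, 4]), tf.random.normal([16, 2]))).batch(4)
# MovingAverageCallback swaps the averaged weights in for each eval pass;
# AverageModelCheckpoint writes them under model_dir/average/.
model.fit(data, epochs=2, callbacks=cb, validation_data=data)
```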
official/vision/image_classification/classifier_trainer.py (view file @ eb6fa0b2)

```diff
@@ -360,18 +360,18 @@ def train_and_eval(
       model_dir=params.model_dir,
       train_steps=train_steps)

-  callbacks = custom_callbacks.get_callbacks(
-      model_checkpoint=params.train.callbacks.enable_checkpoint_and_export,
-      include_tensorboard=params.train.callbacks.enable_tensorboard,
-      time_history=params.train.callbacks.enable_time_history,
-      track_lr=params.train.tensorboard.track_lr,
-      write_model_weights=params.train.tensorboard.write_model_weights,
-      initial_step=initial_epoch * train_steps,
-      batch_size=train_builder.global_batch_size,
-      log_steps=params.train.time_history.log_steps,
-      model_dir=params.model_dir)
   serialize_config(params=params, model_dir=params.model_dir)
+  # TODO(dankondratyuk): callbacks significantly slow down training
+  callbacks = custom_callbacks.get_callbacks(
+      model_checkpoint=params.train.callbacks.enable_checkpoint_and_export,
+      include_tensorboard=params.train.callbacks.enable_tensorboard,
+      time_history=params.train.callbacks.enable_time_history,
+      track_lr=params.train.tensorboard.track_lr,
+      write_model_weights=params.train.tensorboard.write_model_weights,
+      initial_step=initial_epoch * train_steps,
+      batch_size=train_builder.global_batch_size,
+      log_steps=params.train.time_history.log_steps,
+      model_dir=params.model_dir)

   if params.evaluation.skip_eval:
     validation_kwargs = {}
@@ -388,7 +388,9 @@ def train_and_eval(
       steps_per_epoch=train_steps,
       initial_epoch=initial_epoch,
       callbacks=callbacks,
-      **validation_kwargs)
+      **validation_kwargs,
+      experimental_steps_per_execution=params.train.steps_per_loop,
+      verbose=2)

   validation_output = None
   if not params.evaluation.skip_eval:
```
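The `validation_kwargs` pattern in the hunk above conditionally forwards evaluation arguments to `fit()`. A stripped-down, self-contained sketch of the same pattern (all values here are illustrative stand-ins for the trainer's params and datasets):

```python
import tensorflow as tf

# Stand-in for params.evaluation.skip_eval and the real datasets.
skip_eval = False
dataset = tf.data.Dataset.from_tensor_slices(
    (tf.random.normal([16, 4]), tf.random.normal([16, 1]))).batch(4)

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.compile(optimizer='sgd', loss='mse')

validation_kwargs = {}
if not skip_eval:
  # Only passed to fit() when evaluation is enabled.
  validation_kwargs = {'validation_data': dataset, 'validation_steps': 4}

model.fit(dataset, epochs=1, callbacks=[], **validation_kwargs)
```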
official/vision/image_classification/configs/base_configs.py (view file @ eb6fa0b2)

```diff
@@ -82,6 +82,8 @@ class TrainConfig(base_config.Config):
     callbacks: An instance of CallbacksConfig.
     metrics: An instance of MetricsConfig.
     tensorboard: An instance of TensorboardConfig.
+    steps_per_loop: The number of batches to run during each `tf.function`
+      call during training, which can increase training speed.
   """
   resume_checkpoint: bool = None
@@ -91,6 +93,7 @@ class TrainConfig(base_config.Config):
   metrics: MetricsConfig = None
   tensorboard: TensorboardConfig = TensorboardConfig()
   time_history: TimeHistoryConfig = TimeHistoryConfig()
+  steps_per_loop: int = None


 @dataclasses.dataclass
@@ -176,6 +179,7 @@ class LearningRateConfig(base_config.Config):
     multipliers: multipliers used in piecewise constant decay with warmup.
     scale_by_batch_size: Scale the learning rate by a fraction of the batch
       size. Set to 0 for no scaling (default).
+    staircase: Apply exponential decay at discrete values instead of continuous.
   """
   name: str = None
@@ -187,6 +191,7 @@ class LearningRateConfig(base_config.Config):
   boundaries: List[int] = None
   multipliers: List[float] = None
   scale_by_batch_size: float = 0.
+  staircase: bool = None


 @dataclasses.dataclass
```
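The new `steps_per_loop` option follows the usual host-loop pattern: several training batches run inside a single traced `tf.function` call, amortizing dispatch overhead. A minimal standalone sketch of the idea (not the trainer's actual loop; names are illustrative):

```python
import tensorflow as tf

@tf.function
def train_n_steps(iterator, steps):
  """Runs `steps` batches inside one traced call, like steps_per_loop."""
  for _ in tf.range(steps):
    batch = next(iterator)
    # ... one optimizer step on `batch` would go here ...
    tf.print('processed a batch of', tf.size(batch), 'elements')

dataset = tf.data.Dataset.range(100).batch(8)
iterator = iter(dataset)
train_n_steps(iterator, tf.constant(3))  # 3 batches per outer call
```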
official/vision/image_classification/configs/configs.py (view file @ eb6fa0b2)

```diff
@@ -54,7 +54,8 @@ class EfficientNetImageNetConfig(base_configs.ExperimentConfig):
       metrics=['accuracy', 'top_5'],
       time_history=base_configs.TimeHistoryConfig(log_steps=100),
       tensorboard=base_configs.TensorboardConfig(track_lr=True,
-                                                 write_model_weights=False))
+                                                 write_model_weights=False),
+      steps_per_loop=1)
   evaluation: base_configs.EvalConfig = base_configs.EvalConfig(
       epochs_between_evals=1,
       steps=None)
@@ -86,7 +87,8 @@ class ResNetImagenetConfig(base_configs.ExperimentConfig):
       metrics=['accuracy', 'top_5'],
       time_history=base_configs.TimeHistoryConfig(log_steps=100),
       tensorboard=base_configs.TensorboardConfig(track_lr=True,
-                                                 write_model_weights=False))
+                                                 write_model_weights=False),
+      steps_per_loop=1)
   evaluation: base_configs.EvalConfig = base_configs.EvalConfig(
       epochs_between_evals=1,
       steps=None)
```
official/vision/image_classification/configs/examples/efficientnet/imagenet/efficientnet-b0-gpu.yaml (view file @ eb6fa0b2)

```diff
@@ -40,6 +40,8 @@ model:
     name: 'rmsprop'
     momentum: 0.9
     decay: 0.9
+    moving_average_decay: 0.0
+    lookahead: false
   learning_rate:
     name: 'exponential'
   loss:
```
official/vision/image_classification/configs/examples/efficientnet/imagenet/efficientnet-b0-tpu.yaml (view file @ eb6fa0b2)

```diff
@@ -39,7 +39,7 @@ model:
     name: 'rmsprop'
     momentum: 0.9
     decay: 0.9
-    moving_average_decay: 0.
+    moving_average_decay: 0.0
     lookahead: false
   learning_rate:
     name: 'exponential'
```
official/vision/image_classification/configs/examples/efficientnet/imagenet/efficientnet-b1-gpu.yaml (view file @ eb6fa0b2)

```diff
@@ -33,6 +33,8 @@ model:
     name: 'rmsprop'
     momentum: 0.9
     decay: 0.9
+    moving_average_decay: 0.0
+    lookahead: false
   learning_rate:
     name: 'exponential'
   loss:
```
official/vision/image_classification/configs/examples/efficientnet/imagenet/efficientnet-b1-tpu.yaml (view file @ eb6fa0b2)

```diff
@@ -38,6 +38,8 @@ model:
     name: 'rmsprop'
     momentum: 0.9
     decay: 0.9
+    moving_average_decay: 0.0
+    lookahead: false
   learning_rate:
     name: 'exponential'
   loss:
```
official/vision/image_classification/efficientnet/common_modules.py (view file @ eb6fa0b2)

```diff
@@ -19,15 +19,14 @@ from __future__ import division
 from __future__ import print_function

 import numpy as np
-import tensorflow.compat.v1 as tf1
 import tensorflow as tf
+import tensorflow.compat.v1 as tf1
 from typing import Text, Optional
 from tensorflow.python.tpu import tpu_function


-@tf.keras.utils.register_keras_serializable(package='Text')
+@tf.keras.utils.register_keras_serializable(package='Vision')
 class TpuBatchNormalization(tf.keras.layers.BatchNormalization):
   """Cross replica batch normalization."""
```
official/vision/image_classification/efficientnet/efficientnet_config.py (view file @ eb6fa0b2)

```diff
@@ -72,4 +72,5 @@ class EfficientNetModelConfig(base_configs.ModelConfig):
           decay_epochs=2.4,
           decay_rate=0.97,
           warmup_epochs=5,
-          scale_by_batch_size=1. / 128.)
+          scale_by_batch_size=1. / 128.,
+          staircase=True)
```
official/vision/image_classification/optimizer_factory.py (view file @ eb6fa0b2)
```diff
@@ -22,10 +22,230 @@ from absl import logging
 import tensorflow as tf
 import tensorflow_addons as tfa
-from typing import Any, Dict, Text
+from typing import Any, Dict, Text, List

 from official.vision.image_classification import learning_rate
 from official.vision.image_classification.configs import base_configs

+# pylint: disable=protected-access
```

The bulk of this hunk adds a new `MovingAverage` optimizer wrapper:
````python
class MovingAverage(tf.keras.optimizers.Optimizer):
  """Optimizer that computes a moving average of the variables.

  Empirically it has been found that using the moving average of the trained
  parameters of a deep network is better than using its trained parameters
  directly. This optimizer allows you to compute this moving average and swap
  the variables at save time so that any code outside of the training loop
  will use by default the average values instead of the original ones.

  Example of usage for training:
  ```python
  opt = tf.keras.optimizers.SGD(learning_rate)
  opt = MovingAverage(opt)
  opt.shadow_copy(model)
  ```

  At test time, swap the shadow variables to evaluate on the averaged weights:
  ```python
  opt.swap_weights()
  # Test eval the model here
  opt.swap_weights()
  ```
  """

  def __init__(self,
               optimizer: tf.keras.optimizers.Optimizer,
               average_decay: float = 0.99,
               start_step: int = 0,
               dynamic_decay: bool = True,
               name: Text = 'moving_average',
               **kwargs):
    """Construct a new MovingAverage optimizer.

    Args:
      optimizer: `tf.keras.optimizers.Optimizer` that will be
        used to compute and apply gradients.
      average_decay: float. Decay to use to maintain the moving averages
        of trained variables.
      start_step: int. What step to start the moving average.
      dynamic_decay: bool. Whether to change the decay based on the number
        of optimizer updates. Decay will start at 0.1 and gradually increase
        up to `average_decay` after each optimizer update. This behavior is
        similar to `tf.train.ExponentialMovingAverage` in TF 1.x.
      name: Optional name for the operations created when applying
        gradients. Defaults to "moving_average".
      **kwargs: keyword arguments. Allowed to be {`clipnorm`,
        `clipvalue`, `lr`, `decay`}.
    """
    super(MovingAverage, self).__init__(name, **kwargs)
    self._optimizer = optimizer
    self._average_decay = average_decay
    self._start_step = tf.constant(start_step, tf.float32)
    self._dynamic_decay = dynamic_decay

  def shadow_copy(self, model: tf.keras.Model):
    """Creates shadow variables for the given model weights."""
    for var in model.weights:
      self.add_slot(var, 'average', initializer='zeros')
    self._average_weights = [
        self.get_slot(var, 'average') for var in model.weights
    ]
    self._model_weights = model.weights

  @property
  def has_shadow_copy(self):
    """Whether this optimizer has created shadow variables."""
    return self._model_weights is not None

  def _create_slots(self, var_list):
    self._optimizer._create_slots(var_list=var_list)  # pylint: disable=protected-access

  def apply_gradients(self, grads_and_vars, name: Text = None):
    result = self._optimizer.apply_gradients(grads_and_vars, name)
    self.update_average(self._optimizer.iterations)
    return result

  @tf.function
  def update_average(self, step: tf.Tensor):
    step = tf.cast(step, tf.float32)
    if step < self._start_step:
      decay = tf.constant(0., tf.float32)
    elif self._dynamic_decay:
      decay = step - self._start_step
      decay = tf.minimum(self._average_decay, (1. + decay) / (10. + decay))
    else:
      decay = self._average_decay

    def _apply_moving(v_moving, v_normal):
      diff = v_moving - v_normal
      v_moving.assign_sub(tf.cast(1. - decay, v_moving.dtype) * diff)
      return v_moving

    def _update(strategy, v_moving_and_v_normal):
      for v_moving, v_normal in v_moving_and_v_normal:
        strategy.extended.update(v_moving, _apply_moving, args=(v_normal,))

    ctx = tf.distribute.get_replica_context()
    return ctx.merge_call(_update, args=(zip(self._average_weights,
                                             self._model_weights),))
````
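The `assign_sub` form in `_apply_moving` is algebraically the standard EMA update, `v_moving = decay * v_moving + (1 - decay) * v_normal`. A plain-Python sketch of how the dynamic decay ramps from 0.1 toward `average_decay` (values illustrative):

```python
# Sketch of the dynamic_decay schedule used in update_average above.
average_decay = 0.99
start_step = 0

def dynamic_decay(step):
  t = step - start_step
  return min(average_decay, (1. + t) / (10. + t))

for step in [0, 10, 100, 1000]:
  print(step, round(dynamic_decay(step), 4))
# 0 -> 0.1, 10 -> 0.55, 100 -> 0.9182, 1000 -> 0.99 (capped at average_decay)
```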
`MovingAverage` (continued):

```python
  def swap_weights(self):
    """Swap the average and moving weights.

    This is a convenience method to allow one to evaluate the averaged weights
    at test time. Loads the weights stored in `self._average` into the model,
    keeping a copy of the original model weights. Swapping twice will return
    the original weights.
    """
    if tf.distribute.in_cross_replica_context():
      strategy = tf.distribute.get_strategy()
      strategy.run(self._swap_weights, args=())
    else:
      raise ValueError('Swapping weights must occur under a '
                       'tf.distribute.Strategy')

  @tf.function
  def _swap_weights(self):
    def fn_0(a, b):
      a.assign_add(b)
      return a
    def fn_1(b, a):
      b.assign(a - b)
      return b
    def fn_2(a, b):
      a.assign_sub(b)
      return a

    def swap(strategy, a_and_b):
      """Swap `a` and `b` and mirror to all devices."""
      for a, b in a_and_b:
        strategy.extended.update(a, fn_0, args=(b,))  # a = a + b
        strategy.extended.update(b, fn_1, args=(a,))  # b = a - b
        strategy.extended.update(a, fn_2, args=(b,))  # a = a - b

    ctx = tf.distribute.get_replica_context()
    return ctx.merge_call(swap, args=(zip(self._average_weights,
                                          self._model_weights),))
```
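The three `update` calls implement the classic in-place swap that needs no temporary buffer, which matters when the tensors being swapped are large model weights. With plain numbers:

```python
# Sketch of the three-step swap used by _swap_weights above.
a, b = 1.0, 2.0
a = a + b   # a = 3.0
b = a - b   # b = 1.0  (the old a)
a = a - b   # a = 2.0  (the old b)
assert (a, b) == (2.0, 1.0)
```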
`MovingAverage` (continued):

```python
  def assign_average_vars(self, var_list: List[tf.Variable]):
    """Assign variables in var_list with their respective averages.

    Args:
      var_list: List of model variables to be assigned to their average.

    Returns:
      assign_op: The op corresponding to the assignment operation of
        variables to their average.
    """
    assign_op = tf.group([
        var.assign(self.get_slot(var, 'average'))
        for var in var_list
        if var.trainable
    ])
    return assign_op

  def _create_hypers(self):
    self._optimizer._create_hypers()  # pylint: disable=protected-access

  def _prepare(self, var_list):
    return self._optimizer._prepare(var_list=var_list)  # pylint: disable=protected-access

  @property
  def iterations(self):
    return self._optimizer.iterations

  @iterations.setter
  def iterations(self, variable):
    self._optimizer.iterations = variable

  @property
  def weights(self):
    # return self._weights + self._optimizer.weights
    return self._optimizer.weights

  @property
  def lr(self):
    return self._optimizer._get_hyper('learning_rate')

  @lr.setter
  def lr(self, lr):
    self._optimizer._set_hyper('learning_rate', lr)

  @property
  def learning_rate(self):
    return self._optimizer._get_hyper('learning_rate')

  @learning_rate.setter
  def learning_rate(self, learning_rate):  # pylint: disable=redefined-outer-name
    self._optimizer._set_hyper('learning_rate', learning_rate)

  def _resource_apply_dense(self, grad, var):
    return self._optimizer._resource_apply_dense(grad, var)

  def _resource_apply_sparse(self, grad, var, indices):
    return self._optimizer._resource_apply_sparse(grad, var, indices)

  def _resource_apply_sparse_duplicate_indices(self, grad, var, indices):
    return self._optimizer._resource_apply_sparse_duplicate_indices(
        grad, var, indices)

  def get_config(self):
    config = {
        'optimizer': tf.keras.optimizers.serialize(self._optimizer),
        'average_decay': self._average_decay,
        'start_step': self._start_step,
        'dynamic_decay': self._dynamic_decay,
    }
    base_config = super(MovingAverage, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  @classmethod
  def from_config(cls, config, custom_objects=None):
    optimizer = tf.keras.optimizers.deserialize(
        config.pop('optimizer'),
        custom_objects=custom_objects,
    )
    return cls(optimizer, **config)
```
`build_optimizer` now uses the local `MovingAverage` instead of `tfa.optimizers.MovingAverage`, and applies it after Lookahead:

```diff
@@ -95,16 +315,17 @@ def build_optimizer(
   else:
     raise ValueError('Unknown optimizer %s' % optimizer_name)

+  if params.get('lookahead', None):
+    logging.info('Using lookahead optimizer.')
+    optimizer = tfa.optimizers.Lookahead(optimizer)
+
+  # Moving average should be applied last, as it's applied at test time
   moving_average_decay = params.get('moving_average_decay', 0.)
   if moving_average_decay is not None and moving_average_decay > 0.:
     logging.info('Including moving average decay.')
-    optimizer = tfa.optimizers.MovingAverage(
+    optimizer = MovingAverage(
         optimizer,
-        average_decay=params['moving_average_decay'],
-        num_updates=None)
+        average_decay=moving_average_decay)
-  if params.get('lookahead', None):
-    logging.info('Using lookahead optimizer.')
-    optimizer = tfa.optimizers.Lookahead(optimizer)
   return optimizer
```
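A hedged sketch of the optimizer stack this now assembles when both config flags are set (constructor values are illustrative); the comment in the hunk explains the ordering: the moving average must wrap last so the shadow weights track the fully wrapped optimizer's updates:

```python
import tensorflow as tf
import tensorflow_addons as tfa
from official.vision.image_classification import optimizer_factory

base = tf.keras.optimizers.SGD(learning_rate=0.1)
opt = tfa.optimizers.Lookahead(base)  # applied first ('lookahead: true')
opt = optimizer_factory.MovingAverage(
    opt, average_decay=0.999)  # applied last; averages are used at test time
```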
```diff
@@ -139,7 +360,8 @@ def build_learning_rate(params: base_configs.LearningRateConfig,
     lr = tf.keras.optimizers.schedules.ExponentialDecay(
         initial_learning_rate=base_lr,
         decay_steps=decay_steps,
-        decay_rate=decay_rate)
+        decay_rate=decay_rate,
+        staircase=params.staircase)
   elif decay_type == 'piecewise_constant_with_warmup':
     logging.info('Using Piecewise constant decay with warmup. '
                  'Parameters: batch_size: %d, epoch_size: %d, '
```
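A short sketch contrasting continuous and staircase exponential decay, as toggled by the new `staircase` config flag (the schedule values here are illustrative, not the repo's defaults):

```python
import tensorflow as tf

continuous = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=0.1, decay_steps=100, decay_rate=0.97)
staircase = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=0.1, decay_steps=100, decay_rate=0.97,
    staircase=True)

print(continuous(50).numpy())  # ~0.0985: decays fractionally between steps
print(staircase(50).numpy())   # 0.1: drops only at steps 100, 200, ...
```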
official/vision/image_classification/optimizer_factory_test.py (view file @ eb6fa0b2)

```diff
@@ -35,9 +35,9 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
       ('adam', 'adam', 0., False),
       ('adamw', 'adamw', 0., False),
       ('momentum_lookahead', 'momentum', 0., True),
-      ('sgd_ema', 'sgd', 0.001, False),
-      ('momentum_ema', 'momentum', 0.001, False),
-      ('rmsprop_ema', 'rmsprop', 0.001, False))
+      ('sgd_ema', 'sgd', 0.999, False),
+      ('momentum_ema', 'momentum', 0.999, False),
+      ('rmsprop_ema', 'rmsprop', 0.999, False))
   def test_optimizer(self, optimizer_name, moving_average_decay, lookahead):
     """Smoke test to be sure no syntax errors."""
     params = {
```