ModelZoo / ResNet50_tensorflow · Commits

Commit 5ffcc5b6 (unverified)
Authored Jul 21, 2021 by Anirudh Vegesana; committed by GitHub on Jul 21, 2021

    Merge branch 'purdue-yolo' into detection_generator_pr

Parents: 0b81a843, 76e0c014

Changes: 190 (page 1 of 10 shown below)
Showing 20 changed files with 597 additions and 88 deletions (+597 / -88):
- README.md (+15 / -0)
- official/README.md (+4 / -18)
- official/common/distribute_utils.py (+6 / -3)
- official/common/distribute_utils_test.py (+1 / -1)
- official/common/flags.py (+20 / -2)
- official/core/actions.py (+156 / -0)
- official/core/actions_test.py (+81 / -0)
- official/core/base_task.py (+1 / -1)
- official/core/base_trainer_test.py (+8 / -5)
- official/core/config_definitions.py (+14 / -14)
- official/core/export_base.py (+1 / -1)
- official/core/input_reader.py (+65 / -37)
- official/core/train_lib.py (+18 / -4)
- official/core/train_utils.py (+55 / -2)
- official/modeling/optimization/adafactor_optimizer.py (+20 / -0)
- official/modeling/optimization/configs/learning_rate_config.py (+8 / -0)
- official/modeling/optimization/configs/optimization_config.py (+1 / -0)
- official/modeling/optimization/configs/optimizer_config.py (+19 / -0)
- official/modeling/optimization/lr_schedule.py (+69 / -0)
- official/modeling/optimization/lr_schedule_test.py (+35 / -0)
README.md

@@ -30,3 +30,18 @@ If you want to contribute, please review the [contribution guidelines](https://g
 ## License

 [Apache License 2.0](LICENSE)
+
+## Citing TensorFlow Model Garden
+
+If you use TensorFlow Model Garden in your research, please cite this repository.
+
+```
+@misc{tensorflowmodelgarden2020,
+  author = {Hongkun Yu and Chen Chen and Xianzhi Du and Yeqing Li and
+            Abdullah Rashwan and Le Hou and Pengchong Jin and Fan Yang and
+            Frederick Liu and Jaeyoun Kim and Jing Li},
+  title = {{TensorFlow Model Garden}},
+  howpublished = {\url{https://github.com/tensorflow/models}},
+  year = {2020}
+}
+```
official/README.md

@@ -40,7 +40,7 @@ In the near future, we will add:
 | Model | Reference (Paper) |
 |-------|-------------------|
 | [MNIST](vision/image_classification) | A basic model to classify digits from the [MNIST dataset](http://yann.lecun.com/exdb/mnist/) |
-| [ResNet](vision/image_classification) | [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) |
+| [ResNet](vision/beta/MODEL_GARDEN.md) | [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) |
 | [ResNet-RS](vision/beta/MODEL_GARDEN.md) | [Revisiting ResNets: Improved Training and Scaling Strategies](https://arxiv.org/abs/2103.07579) |
 | [EfficientNet](vision/image_classification) | [EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks](https://arxiv.org/abs/1905.11946) |

@@ -48,10 +48,10 @@ In the near future, we will add:
 | Model | Reference (Paper) |
 |-------|-------------------|
-| [RetinaNet](vision/detection) | [Focal Loss for Dense Object Detection](https://arxiv.org/abs/1708.02002) |
-| [Mask R-CNN](vision/detection) | [Mask R-CNN](https://arxiv.org/abs/1703.06870) |
+| [RetinaNet](vision/beta/MODEL_GARDEN.md) | [Focal Loss for Dense Object Detection](https://arxiv.org/abs/1708.02002) |
+| [Mask R-CNN](vision/beta/MODEL_GARDEN.md) | [Mask R-CNN](https://arxiv.org/abs/1703.06870) |
 | [ShapeMask](vision/detection) | [ShapeMask: Learning to Segment Novel Objects by Refining Shape Priors](https://arxiv.org/abs/1904.03239) |
-| [SpineNet](vision/detection) | [SpineNet: Learning Scale-Permuted Backbone for Recognition and Localization](https://arxiv.org/abs/1912.05027) |
+| [SpineNet](vision/beta/MODEL_GARDEN.md) | [SpineNet: Learning Scale-Permuted Backbone for Recognition and Localization](https://arxiv.org/abs/1912.05027) |

 ### Natural Language Processing

@@ -163,17 +163,3 @@ pip3 install tensorflow-text-nightly
 ## Contributions

 If you want to contribute, please review the [contribution guidelines](https://github.com/tensorflow/models/wiki/How-to-contribute).
-
-## Citing TF Official Model Garden
-
-To cite this repository:
-
-```
-@software{tfmodels2020github,
-  author = {Chen Chen and Xianzhi Du and Le Hou and Jaeyoun Kim and Jing Li and
-            Yeqing Li and Abdullah Rashwan and Fan Yang and Hongkun Yu},
-  title = {TensorFlow Official Model Garden},
-  url = {https://github.com/tensorflow/models/tree/master/official},
-  year = {2020},
-}
-```
official/common/distribute_utils.py

@@ -102,8 +102,10 @@ def get_distribution_strategy(distribution_strategy="mirrored",
     distribution_strategy: a string specifying which distribution strategy to
       use. Accepted values are "off", "one_device", "mirrored",
       "parameter_server", "multi_worker_mirrored", and "tpu" -- case
-      insensitive. "off" means not to use Distribution Strategy; "tpu" means to
-      use TPUStrategy using `tpu_address`.
+      insensitive. "tpu" means to use TPUStrategy using `tpu_address`.
+      "off" means to use the default strategy which is obtained from
+      tf.distribute.get_strategy (for details on the default strategy, see
+      https://www.tensorflow.org/guide/distributed_training#default_strategy).
     num_gpus: Number of GPUs to run this model.
     all_reduce_alg: Optional. Specifies which algorithm to use when performing
       all-reduce. For `MirroredStrategy`, valid values are "nccl" and

@@ -141,7 +143,8 @@ def get_distribution_strategy(distribution_strategy="mirrored",
     if num_gpus > 1:
       raise ValueError("When {} GPUs are specified, distribution_strategy "
                        "flag cannot be set to `off`.".format(num_gpus))
-    return None
+    # Return the default distribution strategy.
+    return tf.distribute.get_strategy()

   if distribution_strategy == "tpu":
     # When tpu_address is an empty string, we communicate with local TPUs.
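A small usage sketch (not part of the diff) of what this change means for callers: requesting the "off" strategy now yields the default strategy object instead of None, so downstream code can unconditionally enter `strategy.scope()`. The `num_gpus=0` argument is illustrative.

```
import tensorflow as tf
from official.common import distribute_utils

strategy = distribute_utils.get_distribution_strategy(
    distribution_strategy="off", num_gpus=0)
# Previously this returned None; now it is the default strategy.
assert strategy is tf.distribute.get_strategy()

with strategy.scope():
  model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
```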
official/common/distribute_utils_test.py

@@ -43,7 +43,7 @@ class GetDistributionStrategyTest(tf.test.TestCase):

   def test_no_strategy(self):
     ds = distribute_utils.get_distribution_strategy('off')
-    self.assertIsNone(ds)
+    self.assertIs(ds, tf.distribute.get_strategy())

   def test_invalid_strategy(self):
     with self.assertRaisesRegexp(
official/common/flags.py

@@ -18,9 +18,27 @@ from absl import flags


 def define_flags():
-  """Defines flags."""
+  """Defines flags.
+
+  All flags are defined as optional, but in practice most models use some of
+  these flags and so mark_flags_as_required() should be called after calling
+  this function. Typically, 'experiment', 'mode', and 'model_dir' are required.
+  For example:
+  ```
+  from absl import flags
+  from official.common import flags as tfm_flags  # pylint: disable=line-too-long
+  ...
+  tfm_flags.define_flags()
+  flags.mark_flags_as_required(['experiment', 'mode', 'model_dir'])
+  ```
+
+  The reason all flags are optional is because unit tests often do not set or
+  use any of the flags.
+  """
   flags.DEFINE_string(
-      'experiment', default=None, help='The experiment type registered.')
+      'experiment',
+      default=None,
+      help='The experiment type registered, specifying an ExperimentConfig.')

   flags.DEFINE_enum(
       'mode',
official/core/actions.py (new file, mode 100644)

# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Provides TFM orbit actions and associated helper functions/classes."""

import os
from typing import List

import gin
import orbit
import tensorflow as tf
import tensorflow_model_optimization as tfmot

from official.core import base_trainer
from official.core import config_definitions
from official.modeling import optimization


class PruningActions:
  """Train action to update pruning related information.

  This action updates pruning steps at the end of the training loop, and logs
  pruning metrics to TensorBoard.

  This action must be used when training a pruned model to avoid pruning error.
  """

  def __init__(
      self,
      export_dir: str,
      model: tf.keras.Model,
      optimizer: tf.keras.optimizers.Optimizer,
  ):
    """Initializes the instance.

    Args:
      export_dir: `str` for the export directory of the pruning summaries.
      model: `tf.keras.Model` model instance used for training. This will be
        used to assign a pruning step to each prunable weight.
      optimizer: `tf.keras.optimizers.Optimizer` optimizer instance used for
        training. This will be used to find the current training steps.
    """
    self._optimizer = optimizer
    self.update_pruning_step = tfmot.sparsity.keras.UpdatePruningStep()
    self.update_pruning_step.set_model(model)
    self.update_pruning_step.on_train_begin()

    self.pruning_summaries = tfmot.sparsity.keras.PruningSummaries(
        log_dir=export_dir)
    model.optimizer = optimizer
    self.pruning_summaries.set_model(model)

  def __call__(self, output: orbit.runner.Output):
    """Updates the pruning step and logs pruning summaries.

    Args:
      output: The train output to test.
    """
    self.update_pruning_step.on_epoch_end(batch=None)
    self.pruning_summaries.on_epoch_begin(epoch=None)


class EMACheckpointing:
  """Eval action to save checkpoint with average weights when EMA is used.

  This action swaps the weights of the model with the average weights, then it
  saves the checkpoint under export_dir/ema_checkpoints. Checkpointing is
  expensive for large models, so doing this action in eval is more efficient
  than in training.
  """

  def __init__(self,
               export_dir: str,
               optimizer: tf.keras.optimizers.Optimizer,
               checkpoint: tf.train.Checkpoint,
               max_to_keep: int = 1):
    """Initializes the instance.

    Args:
      export_dir: `str` for the export directory of the EMA average weights.
      optimizer: `tf.keras.optimizers.Optimizer` optimizer instance used for
        training. This will be used to swap the model weights with the average
        weights.
      checkpoint: `tf.train.Checkpoint` instance.
      max_to_keep: `int` for max checkpoints to keep in ema_checkpoints subdir.
    """
    if not isinstance(optimizer, optimization.ExponentialMovingAverage):
      raise ValueError('Optimizer has to be instance of '
                       'optimization.ExponentialMovingAverage for '
                       'EMACheckpointing action')

    export_dir = os.path.join(export_dir, 'ema_checkpoints')
    tf.io.gfile.makedirs(os.path.dirname(export_dir))
    self._optimizer = optimizer
    self._checkpoint = checkpoint
    self._checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        directory=export_dir,
        max_to_keep=max_to_keep,
        checkpoint_name='average_weights')

  def __call__(self, output: orbit.runner.Output):
    """Swaps model weights, and saves the checkpoint.

    Args:
      output: The train or eval output to test.
    """
    self._optimizer.swap_weights()
    self._checkpoint_manager.save(checkpoint_number=self._optimizer.iterations)
    self._optimizer.swap_weights()


@gin.configurable
def get_eval_actions(params: config_definitions.ExperimentConfig,
                     trainer: base_trainer.Trainer,
                     model_dir: str) -> List[orbit.Action]:
  """Gets eval actions for TFM trainer."""
  eval_actions = []
  # Adds ema checkpointing action to save the average weights under
  # ema_checkpoints subdir.
  if isinstance(trainer.optimizer, optimization.ExponentialMovingAverage):
    eval_actions.append(
        EMACheckpointing(
            export_dir=model_dir,
            optimizer=trainer.optimizer,
            checkpoint=trainer.checkpoint,
            max_to_keep=params.trainer.max_to_keep))

  return eval_actions


@gin.configurable
def get_train_actions(params: config_definitions.ExperimentConfig,
                      trainer: base_trainer.Trainer,
                      model_dir: str) -> List[orbit.Action]:
  """Gets train actions for TFM trainer."""
  train_actions = []
  # Adds pruning callback actions.
  if hasattr(params.task, 'pruning'):
    train_actions.append(
        PruningActions(
            export_dir=model_dir,
            model=trainer.model,
            optimizer=trainer.optimizer))

  return train_actions
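As an orientation aid (not part of the commit), here is a rough sketch of how these action factories are consumed. The train_lib.py change later on this page passes them to the orbit controller, and a standalone driver could do the same; `params`, `trainer`, and `model_dir` are assumed to come from the usual TFM setup.

```
import orbit

from official.core import actions
from official.core import base_trainer
from official.core import config_definitions


def build_controller(params: config_definitions.ExperimentConfig,
                     trainer: base_trainer.Trainer,
                     model_dir: str) -> orbit.Controller:
  """Wires the new train/eval actions into an orbit.Controller (sketch)."""
  return orbit.Controller(
      trainer=trainer,
      evaluator=trainer,
      global_step=trainer.global_step,
      steps_per_loop=params.trainer.steps_per_loop,
      train_actions=actions.get_train_actions(params, trainer, model_dir),
      eval_actions=actions.get_eval_actions(params, trainer, model_dir))
```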
official/core/actions_test.py (new file, mode 100644)

# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for TFM actions."""

import os

from absl.testing import parameterized
import tensorflow as tf

from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations

from official.core import actions
from official.modeling import optimization


class TestModel(tf.Module):

  def __init__(self):
    self.value = tf.Variable(0)

  @tf.function(input_signature=[])
  def __call__(self):
    return self.value


def all_strategy_combinations():
  return combinations.combine(
      distribution=[
          strategy_combinations.cloud_tpu_strategy,
          strategy_combinations.one_device_strategy_gpu,
      ],)


class ActionsTest(tf.test.TestCase, parameterized.TestCase):

  @combinations.generate(all_strategy_combinations())
  def test_ema_checkpointing(self, distribution):
    with distribution.scope():
      directory = self.create_tempdir()
      model = TestModel()
      optimizer = tf.keras.optimizers.SGD()
      optimizer = optimization.ExponentialMovingAverage(
          optimizer, trainable_weights_only=False)

      # Creates average weights for the model variables. Average weights are
      # initialized to zero.
      optimizer.shadow_copy(model)
      checkpoint = tf.train.Checkpoint(model=model)

      # Changes model.value to 3, average value is still 0.
      model.value.assign(3)
      # Checks model.value is 3.
      self.assertEqual(model(), 3)

      ema_action = actions.EMACheckpointing(directory, optimizer, checkpoint)

      ema_action({})
      self.assertNotEmpty(
          tf.io.gfile.glob(os.path.join(directory, 'ema_checkpoints')))

      checkpoint.read(
          tf.train.latest_checkpoint(
              os.path.join(directory, 'ema_checkpoints')))

      # Checks model.value is 0 after swapping.
      self.assertEqual(model(), 0)


if __name__ == '__main__':
  tf.test.main()
official/core/base_task.py

@@ -79,7 +79,7 @@ class Task(tf.Module, metaclass=abc.ABCMeta):
     optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())

     # Configuring optimizer when loss_scale is set in runtime config. This helps
     # avoiding overflow/underflow for float16 computations.
-    if runtime_config and runtime_config.loss_scale:
+    if runtime_config:
       optimizer = performance.configure_optimizer(
           optimizer,
           use_float16=runtime_config.mixed_precision_dtype == "float16",
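For context, a minimal sketch (not from the commit) of the runtime settings this branch reacts to; the field names follow `config_definitions.RuntimeConfig`, and the particular values are illustrative only.

```
from official.core import config_definitions as cfg

# With this change, performance.configure_optimizer() is applied whenever a
# runtime config is present, not only when loss_scale is explicitly set.
runtime = cfg.RuntimeConfig(
    mixed_precision_dtype='float16',
    loss_scale='dynamic')
```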
official/core/base_trainer_test.py

@@ -303,13 +303,16 @@ class TrainerTest(tf.test.TestCase, parameterized.TestCase):
         },
     })))
     trainer = self.create_test_trainer(config)
-    if mixed_precision_dtype != 'float16':
-      self.assertIsInstance(trainer.optimizer, tf.keras.optimizers.SGD)
-    elif mixed_precision_dtype == 'float16' and loss_scale is None:
-      self.assertIsInstance(trainer.optimizer, tf.keras.optimizers.SGD)
-    else:
+    if mixed_precision_dtype == 'float16':
       self.assertIsInstance(trainer.optimizer,
                             tf.keras.mixed_precision.LossScaleOptimizer)
+      if loss_scale in (None, 'dynamic'):
+        self.assertTrue(trainer.optimizer.dynamic)
+      else:
+        self.assertFalse(trainer.optimizer.dynamic)
+        self.assertEqual(trainer.optimizer.initial_scale, loss_scale)
+    else:
+      self.assertIsInstance(trainer.optimizer, tf.keras.optimizers.SGD)

     metrics = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
     self.assertIn('training_loss', metrics)
official/core/config_definitions.py

@@ -29,12 +29,13 @@ class DataConfig(base_config.Config):
   """The base configuration for building datasets.

   Attributes:
-    input_path: The path to the input. It can be either (1) a str indicating
-      a file path/pattern, or (2) a str indicating multiple file paths/patterns
-      separated by comma (e.g "a, b, c" or no spaces "a,b,c"), or
-      (3) a list of str, each of which is a file path/pattern or multiple file
-      paths/patterns separated by comma.
-      It should not be specified when the following `tfds_name` is specified.
+    input_path: The path to the input. It can be either (1) a str indicating a
+      file path/pattern, or (2) a str indicating multiple file paths/patterns
+      separated by comma (e.g "a, b, c" or no spaces "a,b,c"), or (3) a list of
+      str, each of which is a file path/pattern or multiple file paths/patterns
+      separated by comma, or (4) a dictionary of the previous three approaches
+      for more advanced data mixing using named access. It should not be
+      specified when the following `tfds_name` is specified.
     tfds_name: The name of the tensorflow dataset (TFDS). It should not be
       specified when the above `input_path` is specified.
     tfds_split: A str indicating which split of the data to load from TFDS. It

@@ -46,8 +47,8 @@ class DataConfig(base_config.Config):
     shuffle_buffer_size: The buffer size used for shuffling training data.
     cache: Whether to cache dataset examples. If `True`, we will cache the
       dataset after applying the decode_fn and parse_fn. It can be used to avoid
-      re-reading from disk, re-decoding and re-parsing the example on the second
-      epoch, but it requires significant memory overhead.
+      re-reading from disk, re-decoding and re-parsing the example on the
+      second epoch, but it requires significant memory overhead.
     cycle_length: The number of files that will be processed concurrently when
       interleaving files.
     block_length: The number of consecutive elements to produce from each input

@@ -60,10 +61,9 @@ class DataConfig(base_config.Config):
       preprocessing onto during training. The URI should be in the format
       "protocol://address", e.g. "grpc://tf-data-service:5050". It can be
       overridden by `FLAGS.tf_data_service` flag in the binary.
-    tf_data_service_job_name: The name of the tf.data service job. This
-      argument makes it possible for multiple datasets to share the same job.
-      The default behavior is that the dataset creates anonymous, exclusively
-      owned jobs.
+    tf_data_service_job_name: The name of the tf.data service job. This argument
+      makes it possible for multiple datasets to share the same job. The default
+      behavior is that the dataset creates anonymous, exclusively owned jobs.
     tfds_data_dir: A str specifying the directory to read/write TFDS data.
     tfds_as_supervised: A bool. When loading dataset from TFDS, if True, the
       returned tf.data.Dataset will have a 2-tuple structure (input, label)

@@ -75,7 +75,7 @@ class DataConfig(base_config.Config):
       performance.
     seed: An optional seed to use for deterministic shuffling/preprocessing.
   """
-  input_path: Union[Sequence[str], str] = ""
+  input_path: Union[Sequence[str], str, base_config.Config] = ""
   tfds_name: str = ""
   tfds_split: str = ""
   global_batch_size: int = 0
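To illustrate the new dictionary form of `input_path` (this snippet is not in the commit): the dataset keys and file patterns are hypothetical, and it assumes the plain dict is absorbed into a nested `base_config.Config` the way other TFM configs handle nested fields.

```
from official.core import config_definitions as cfg

train_data = cfg.DataConfig(
    input_path={
        'coco': 'gs://my-bucket/coco/train*',
        'objects365': 'gs://my-bucket/objects365/train*',
    },
    global_batch_size=64)
# A combine_fn must then be passed to InputReader to merge the named datasets.
```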
official/core/export_base.py

@@ -82,7 +82,7 @@ def export(export_module: ExportModule,
     The savedmodel directory path.
   """
   ckpt_dir_or_file = checkpoint_path
-  if tf.io.gfile.isdir(ckpt_dir_or_file):
+  if ckpt_dir_or_file is not None and tf.io.gfile.isdir(ckpt_dir_or_file):
     ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
   if ckpt_dir_or_file:
     checkpoint = tf.train.Checkpoint(model=export_module.model)
official/core/input_reader.py

@@ -14,7 +14,7 @@
 """A common dataset reader."""
 import random

-from typing import Any, Callable, List, Optional
+from typing import Any, Callable, List, Optional, Union, Dict, Sequence

 from absl import logging
 import tensorflow as tf

@@ -45,6 +45,7 @@ class InputReader:
                params: cfg.DataConfig,
                dataset_fn=tf.data.TFRecordDataset,
                decoder_fn: Optional[Callable[..., Any]] = None,
+               combine_fn: Optional[Callable[..., Any]] = None,
                sample_fn: Optional[Callable[..., Any]] = None,
                parser_fn: Optional[Callable[..., Any]] = None,
                transform_and_batch_fn: Optional[Callable[

@@ -59,6 +60,9 @@ class InputReader:
         example, it can be `tf.data.TFRecordDataset`.
       decoder_fn: An optional `callable` that takes the serialized data string
         and decodes them into the raw tensor dictionary.
+      combine_fn: An optional `callable` that takes a dictionary of
+        `tf.data.Dataset` objects as input and outputs a combined dataset. It
+        will be executed after the decoder_fn and before the sample_fn.
       sample_fn: An optional `callable` that takes a `tf.data.Dataset` object as
         input and outputs the transformed dataset. It performs sampling on the
         decoded raw tensors dict before the parser_fn.

@@ -78,9 +82,22 @@ class InputReader:
       raise ValueError(
           'At most one of `input_path` and `tfds_name` can be '
           'specified, but got %s and %s.' %
           (params.input_path, params.tfds_name))
+
+    if (isinstance(params.input_path, cfg.base_config.Config) and
+        combine_fn is None):
+      raise ValueError(
+          'A `combine_fn` is required if the `input_path` is a dictionary.')
+
     self._tfds_builder = None
-    self._matched_files = []
+    self._matched_files = None
     if params.input_path:
-      self._matched_files = self._match_files(params.input_path)
+      # we want to combine / mix datasets
+      if isinstance(params.input_path, cfg.base_config.Config):
+        self._matched_files = {}
+        for k, v in params.input_path.as_dict().items():
+          self._matched_files[k] = self._match_files(v)
+      # single dataset
+      else:
+        self._matched_files = self._match_files(params.input_path)
     else:
       # Read dataset from TFDS.

@@ -106,6 +123,7 @@ class InputReader:
     self._dataset_fn = dataset_fn
     self._decoder_fn = decoder_fn
+    self._combine_fn = combine_fn
     self._sample_fn = sample_fn
     self._parser_fn = parser_fn
     self._transform_and_batch_fn = transform_and_batch_fn

@@ -131,7 +149,7 @@ class InputReader:
     self._enable_round_robin_tf_data_service = params.get(
         'enable_round_robin_tf_data_service', False)

-  def _match_files(self, input_path: str) -> List[str]:
+  def _match_files(self, input_path: Union[Sequence[str], str]) -> List[str]:
     """Matches files from an input_path."""
     matched_files = []
     # Read dataset from files.

@@ -195,8 +213,8 @@ class InputReader:
     # Do not enable sharding if tf.data service is enabled, as sharding will be
     # handled inside tf.data service.
-    if self._sharding and input_context and (input_context.num_input_pipelines >
-                                             1):
+    if self._sharding and input_context and (
+        input_context.num_input_pipelines > 1):
       dataset = dataset.shard(input_context.num_input_pipelines,
                               input_context.input_pipeline_id)

@@ -231,8 +249,8 @@ class InputReader:
     dataset = dataset.with_options(options)
     # Do not enable sharding if tf.data service is enabled, as sharding will be
     # handled inside tf.data service.
-    if self._sharding and input_context and (input_context.num_input_pipelines >
-                                             1):
+    if self._sharding and input_context and (
+        input_context.num_input_pipelines > 1):
       dataset = dataset.shard(input_context.num_input_pipelines,
                               input_context.input_pipeline_id)

@@ -281,42 +299,53 @@ class InputReader:
   def _read_decode_and_parse_dataset(
       self,
-      matched_files: List[str],
+      matched_files: Union[Dict[str, List[str]], List[str]],
       dataset_fn,
       batch_size: int,
       input_context: Optional[tf.distribute.InputContext] = None,
       tfds_builder: bool = False) -> tf.data.Dataset:
     """Returns a tf.data.Dataset object after reading, decoding, and parsing."""
-    if tfds_builder:
-      dataset = self._read_tfds(input_context)
-    elif len(matched_files) > 1:
-      if input_context and (len(matched_files) <
-                            input_context.num_input_pipelines):
-        logging.warn(
-            'The number of files %d is less than the number of input pipelines '
-            '%d. We will send all input files to every worker. '
-            'Please consider sharding your data into more files.',
-            len(matched_files), input_context.num_input_pipelines)
-        dataset = self._read_files_then_shard(matched_files, dataset_fn,
-                                              input_context)
-      else:
-        dataset = self._shard_files_then_read(matched_files, dataset_fn,
-                                              input_context)
-    elif len(matched_files) == 1:
-      dataset = self._read_files_then_shard(matched_files, dataset_fn,
-                                            input_context)
-    else:
-      raise ValueError('It is unexpected that `tfds_builder` is None and '
-                       'there is also no `matched_files`.')
+    def _files_to_dataset(files: List[str]) -> tf.data.Dataset:
+      if len(files) > 1:
+        if input_context and (len(files) < input_context.num_input_pipelines):
+          logging.warn(
+              'The number of files %d is less than the number of input pipelines '
+              '%d. We will send all input files to every worker. '
+              'Please consider sharding your data into more files.',
+              len(files), input_context.num_input_pipelines)
+          return self._read_files_then_shard(files, dataset_fn, input_context)
+        else:
+          return self._shard_files_then_read(files, dataset_fn, input_context)
+      elif len(files) == 1:
+        return self._read_files_then_shard(files, dataset_fn, input_context)
+      else:
+        raise ValueError('It is unexpected that `tfds_builder` is None and '
+                         'there is also no `files`.')

-    # If cache is enabled, we will call `shuffle()` later after `cache()`.
-    if self._is_training and not self._cache:
-      dataset = dataset.shuffle(self._shuffle_buffer_size, seed=self._seed)
+    def _shuffle_and_decode(ds):
+      # If cache is enabled, we will call `shuffle()` later after `cache()`.
+      if self._is_training and not self._cache:
+        ds = ds.shuffle(self._shuffle_buffer_size, seed=self._seed)
+      # Decode
+      ds = _maybe_map_fn(ds, self._decoder_fn)
+      return ds

-    dataset = _maybe_map_fn(dataset, self._decoder_fn)
+    if tfds_builder:
+      dataset = self._read_tfds(input_context)
+      dataset = _shuffle_and_decode(dataset)
+    elif isinstance(matched_files, (list, tuple)):
+      dataset = _files_to_dataset(matched_files)
+      dataset = _shuffle_and_decode(dataset)
+    elif isinstance(matched_files, dict):
+      datasets = {}
+      for k, fs in matched_files.items():
+        datasets[k] = _files_to_dataset(fs)
+        datasets[k] = _shuffle_and_decode(datasets[k])
+      dataset = self._combine_fn(datasets)
+    else:
+      raise ValueError('`matched_files` should be a list or dict.')
+
     if self._sample_fn is not None:
       dataset = dataset.apply(self._sample_fn)
     dataset = _maybe_map_fn(dataset, self._parser_fn)

@@ -333,8 +362,7 @@ class InputReader:
     per_replica_batch_size = input_context.get_per_replica_batch_size(
         batch_size) if input_context else batch_size

-    dataset = dataset.batch(
-        per_replica_batch_size,
-        drop_remainder=self._drop_remainder)
+    dataset = dataset.batch(
+        per_replica_batch_size, drop_remainder=self._drop_remainder)
     return dataset
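To make the new `combine_fn` hook concrete, here is a minimal sketch (not from the commit) of a combiner that mixes the named, already-decoded datasets by weighted sampling; the uniform weights and the choice of `tf.data.experimental.sample_from_datasets` are illustrative, not prescribed by this change.

```
import tensorflow as tf


def uniform_combine_fn(datasets):
  """Combines a dict of decoded datasets into one by uniform sampling."""
  names = sorted(datasets.keys())
  weights = [1.0 / len(names)] * len(names)
  return tf.data.experimental.sample_from_datasets(
      [datasets[name] for name in names], weights=weights)

# reader = input_reader.InputReader(
#     params, decoder_fn=my_decoder_fn, combine_fn=uniform_combine_fn, ...)
```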
official/core/train_lib.py

@@ -15,13 +15,15 @@
 """TFM common training driver library."""
 # pytype: disable=attribute-error
 import os
-from typing import Any, Mapping, Tuple, Optional
+from typing import Any, Mapping, Optional, Tuple

 # Import libraries
 from absl import logging
 import orbit
 import tensorflow as tf

+from official.core import actions
 from official.core import base_task
 from official.core import base_trainer
 from official.core import config_definitions

@@ -38,7 +40,8 @@ def run_experiment(
     model_dir: str,
     run_post_eval: bool = False,
     save_summary: bool = True,
-    trainer: Optional[base_trainer.Trainer] = None
+    trainer: Optional[base_trainer.Trainer] = None,
+    controller_cls=orbit.Controller
 ) -> Tuple[tf.keras.Model, Mapping[str, Any]]:
   """Runs train/eval configured by the experiment params.

@@ -54,6 +57,8 @@ def run_experiment(
     save_summary: Whether to save train and validation summary.
     trainer: the base_trainer.Trainer instance. It should be created within the
       strategy.scope().
+    controller_cls: The controller class to manage the train and eval process.
+      Must be an orbit.Controller subclass.

   Returns:
     A 2-tuple of (model, eval_logs).

@@ -73,6 +78,8 @@ def run_experiment(
             params, model_dir))

   if trainer.checkpoint:
+    if model_dir is None:
+      raise ValueError('model_dir must be specified, but got None')
     checkpoint_manager = tf.train.CheckpointManager(
         trainer.checkpoint,
         directory=model_dir,

@@ -85,7 +92,7 @@ def run_experiment(
   else:
     checkpoint_manager = None

-  controller = orbit.Controller(
+  controller = controller_cls(
       strategy=distribution_strategy,
       trainer=trainer if 'train' in mode else None,
       evaluator=trainer,

@@ -97,7 +104,9 @@ def run_experiment(
           params.trainer.validation_summary_subdir) if (save_summary) else None,
-      summary_interval=params.trainer.summary_interval if (save_summary) else None)
+      summary_interval=params.trainer.summary_interval if (save_summary) else None,
+      train_actions=actions.get_train_actions(params, trainer, model_dir),
+      eval_actions=actions.get_eval_actions(params, trainer, model_dir))

   logging.info('Starts to execute mode: %s', mode)
   with distribution_strategy.scope():

@@ -129,6 +138,11 @@ def run_experiment(
     logging.info('Number of trainable params in model: %f Millions.',
                  num_params / 10.**6)
+  flops = train_utils.try_count_flops(trainer.model)
+  if flops is not None:
+    logging.info('FLOPs (multi-adds) in model: %f Billions.',
+                 flops / 10.**9 / 2)
+
   if run_post_eval:
     with distribution_strategy.scope():
       return trainer.model, trainer.evaluate(
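A rough usage sketch (not part of the commit) of the new `controller_cls` hook; the custom controller class, experiment setup, and directory below are hypothetical, and the leading arguments simply follow the existing `run_experiment` signature.

```
import orbit

from official.core import train_lib


class LoggingController(orbit.Controller):
  """Hypothetical controller subclass, e.g. to add extra logging hooks."""


# `strategy`, `task`, and `params` are assumed to come from the usual setup.
model, eval_logs = train_lib.run_experiment(
    distribution_strategy=strategy,
    task=task,
    mode='train_and_eval',
    params=params,
    model_dir='/tmp/experiment',
    controller_cls=LoggingController)
```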
official/core/train_utils.py

@@ -17,7 +17,7 @@ import copy
 import json
 import os
 import pprint
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional, Union

 from absl import logging
 import dataclasses

@@ -25,6 +25,9 @@ import gin
 import orbit
 import tensorflow as tf

+# pylint: disable=g-direct-tensorflow-import
+from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2_as_graph
+# pylint: enable=g-direct-tensorflow-import
 from official.core import base_task
 from official.core import base_trainer
 from official.core import config_definitions

@@ -241,6 +244,9 @@ class ParseConfigOptions:
 def parse_configuration(flags_obj, lock_return=True, print_return=True):
   """Parses ExperimentConfig from flags."""

+  if flags_obj.experiment is None:
+    raise ValueError('The flag --experiment must be specified.')
+
   # 1. Get the default config from the registered experiment.
   params = exp_factory.get_exp_config(flags_obj.experiment)

@@ -285,7 +291,7 @@ def parse_configuration(flags_obj, lock_return=True, print_return=True):
   if print_return:
     pp = pprint.PrettyPrinter()
-    logging.info('Final experiment parameters: %s',
+    logging.info('Final experiment parameters:\n%s',
                  pp.pformat(params.as_dict()))

   return params

@@ -294,6 +300,8 @@ def parse_configuration(flags_obj, lock_return=True, print_return=True):
 def serialize_config(params: config_definitions.ExperimentConfig,
                      model_dir: str):
   """Serializes and saves the experiment config."""
+  if model_dir is None:
+    raise ValueError('model_dir must be specified, but got None')
   params_save_path = os.path.join(model_dir, 'params.yaml')
   logging.info('Saving experiment configuration to %s', params_save_path)
   tf.io.gfile.makedirs(model_dir)

@@ -388,3 +396,48 @@ def try_count_params(model: tf.keras.Model):
           'train step already reached before this run.')
       return None
   return None
+
+
+def try_count_flops(model: Union[tf.Module, tf.keras.Model],
+                    inputs_kwargs: Optional[Dict[str, Any]] = None):
+  """Counts and returns model FLOPs.
+
+  Args:
+    model: A model instance.
+    inputs_kwargs: An optional dictionary of argument pairs specifying inputs'
+      shape specifications to get the corresponding concrete function.
+
+  Returns:
+    The model's FLOPs.
+  """
+  if hasattr(model, 'inputs'):
+    try:
+      # Get input shape and set batch size to 1.
+      if model.inputs:
+        inputs = [
+            tf.TensorSpec([1] + input.shape[1:], input.dtype)
+            for input in model.inputs
+        ]
+        concrete_func = tf.function(model).get_concrete_function(inputs)
+      # If model.inputs is invalid, try to use the input to get concrete
+      # function for model.call (subclass model).
+      else:
+        concrete_func = tf.function(
+            model.call).get_concrete_function(**inputs_kwargs)
+      frozen_func, _ = convert_variables_to_constants_v2_as_graph(concrete_func)
+
+      # Calculate FLOPs.
+      run_meta = tf.compat.v1.RunMetadata()
+      opts = tf.compat.v1.profiler.ProfileOptionBuilder.float_operation()
+      opts['output'] = 'none'
+      flops = tf.compat.v1.profiler.profile(
+          graph=frozen_func.graph, run_meta=run_meta, options=opts)
+      return flops.total_float_ops
+    except Exception as e:  # pylint: disable=broad-except
+      logging.info(
+          'Failed to count model FLOPs with error %s, because the build() '
+          'methods in keras layers were not called. This is probably because '
+          'the model was not fed any input, e.g., the max train step already '
+          'reached before this run.', e)
+      return None
+  return None
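A small sketch (not part of the commit) of calling the new FLOPs counter on a toy functional model; the layer sizes are arbitrary, and the division by two mirrors how train_lib.py reports multiply-adds.

```
import tensorflow as tf

from official.core import train_utils

model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(32, 32, 3)),
    tf.keras.layers.Conv2D(8, 3, activation='relu'),
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(10),
])

flops = train_utils.try_count_flops(model)
if flops is not None:
  print('Approximate multiply-adds:', flops / 2)
```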
official/modeling/optimization/adafactor_optimizer.py (new file, mode 100644)

# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Adafactor optimizer.

A new optimizer that will be open sourced soon.
"""

# pylint: disable=invalid-name, represents an unimplemented class definition.
Adafactor = "Unimplemented"
official/modeling/optimization/configs/learning_rate_config.py

@@ -56,10 +56,12 @@ class StepwiseLrConfig(base_config.Config):
       values[0] [boundaries[0], boundaries[1]] -> values[1]
       [boundaries[n-1], boundaries[n]] -> values[n] [boundaries[n],
       end] -> values[n+1] Defaults to None.
+    offset: An int. The offset applied to steps. Defaults to 0.
   """
   name: str = 'PiecewiseConstantDecay'
   boundaries: Optional[List[int]] = None
   values: Optional[List[float]] = None
+  offset: int = 0


 @dataclasses.dataclass

@@ -76,12 +78,14 @@ class ExponentialLrConfig(base_config.Config):
     decay_rate: A float. Defaults to None.
     staircase: A boolean, if true, learning rate is decreased at discreate
       intervals. Defaults to False.
+    offset: An int. The offset applied to steps. Defaults to 0.
   """
   name: str = 'ExponentialDecay'
   initial_learning_rate: Optional[float] = None
   decay_steps: Optional[int] = None
   decay_rate: Optional[float] = None
   staircase: Optional[bool] = None
+  offset: int = 0


 @dataclasses.dataclass

@@ -99,6 +103,7 @@ class PolynomialLrConfig(base_config.Config):
     power: A float. The power of the polynomial. Defaults to linear, 1.0.
     cycle: A boolean, whether or not it should cycle beyond decay_steps.
       Defaults to False.
+    offset: An int. The offset applied to steps. Defaults to 0.
   """
   name: str = 'PolynomialDecay'
   initial_learning_rate: Optional[float] = None

@@ -106,6 +111,7 @@ class PolynomialLrConfig(base_config.Config):
   end_learning_rate: float = 0.0001
   power: float = 1.0
   cycle: bool = False
+  offset: int = 0


 @dataclasses.dataclass

@@ -122,11 +128,13 @@ class CosineLrConfig(base_config.Config):
       to None.
     alpha: A float. Minimum learning rate value as a fraction of
       initial_learning_rate.
+    offset: An int. The offset applied to steps. Defaults to 0.
   """
   name: str = 'CosineDecay'
   initial_learning_rate: Optional[float] = None
   decay_steps: Optional[int] = None
   alpha: float = 0.0
+  offset: int = 0


 @dataclasses.dataclass
official/modeling/optimization/configs/optimization_config.py

@@ -52,6 +52,7 @@ class OptimizerConfig(oneof.OneOfConfig):
   lars: opt_cfg.LARSConfig = opt_cfg.LARSConfig()
   adagrad: opt_cfg.AdagradConfig = opt_cfg.AdagradConfig()
   slide: opt_cfg.SLIDEConfig = opt_cfg.SLIDEConfig()
+  adafactor: opt_cfg.AdafactorConfig = opt_cfg.AdafactorConfig()


 @dataclasses.dataclass
official/modeling/optimization/configs/optimizer_config.py

@@ -247,3 +247,22 @@ class SLIDEConfig(BaseOptimizerConfig):
   do_gradient_rescaling: bool = True
   norm_type: str = "layer"
   ratio_clip_norm: float = 1e5
+
+
+@dataclasses.dataclass
+class AdafactorConfig(BaseOptimizerConfig):
+  """Configuration for Adafactor optimizer.
+
+  The attributes of this class match the arguments of the Adafactor
+  implementation.
+  """
+  name: str = "Adafactor"
+  factored: bool = True
+  multiply_by_parameter_scale: bool = True
+  beta1: Optional[float] = None
+  decay_rate: float = 0.8
+  step_offset: int = 0
+  clipping_threshold: float = 1.0
+  min_dim_size_to_factor: int = 128
+  epsilon1: float = 1e-30
+  epsilon2: float = 1e-3
official/modeling/optimization/lr_schedule.py

@@ -19,6 +19,75 @@ from typing import Mapping, Any, Union, Optional
 import tensorflow as tf


+def _make_offset_wrapper(new_class_name: str, base_lr_class):
+  """Generates an offset wrapper of a learning rate schedule.
+
+  It returns a subclass of `base_lr_class` that takes an `offset` argument in
+  the constructor. When the new class instance is called, the behavior is:
+    new_class_object(step) = base_lr_class_object(step - offset)
+
+  Example:
+  CosineDecayWithOffset = _make_offset_wrapper(
+      'CosineDecayWithOffset', tf.keras.experimental.CosineDecay)
+  # Use the lr:
+  lr = CosineDecayWithOffset(offset=100, initial_learning_rate=0.1,
+                             decay_steps=1000)
+  lr(101)  # equals tf.keras.experimental.CosineDecay(...)(101-100)
+
+  Args:
+    new_class_name: the name of the new class.
+    base_lr_class: the base learning rate schedule class. Should be subclass of
+      tf.keras.optimizers.schedules.LearningRateSchedule
+
+  Returns:
+    A new class (subclass of the base_lr_class) that can take an offset.
+  """
+  assert issubclass(base_lr_class,
+                    tf.keras.optimizers.schedules.LearningRateSchedule), (
+                        "base_lr_class should be subclass of keras "
+                        f"LearningRateSchedule, got {base_lr_class}")
+
+  # pylint: disable=protected-access,pointless-statement
+  def offset_learning_rate_init(self, offset=0, **kwargs):
+    """Constructs a learning rate schedule object.
+
+    When this object is called, its behavior is
+      self.__call__(step) == base_lr_class.__call__(step - offset)
+
+    Args:
+      self: this object.
+      offset: The offset when computing the learning rate schedule.
+      **kwargs: Pass through to base learning rate class constructor.
+    """
+    base_lr_class.__init__(self, **kwargs)
+    self._offset = offset
+
+  def offset_learning_rate_call(self, step):
+    step = tf.cast(step - self._offset, tf.float32)
+    return base_lr_class.__call__(self, step)
+
+  # pylint: enable=protected-access,pointless-statement
+  return type(
+      new_class_name, (base_lr_class,), {
+          "base_lr_class": base_lr_class,
+          "__init__": offset_learning_rate_init,
+          "__call__": offset_learning_rate_call
+      })
+
+
+PiecewiseConstantDecayWithOffset = _make_offset_wrapper(
+    "PiecewiseConstantDecayWithOffset",
+    tf.keras.optimizers.schedules.PiecewiseConstantDecay)
+PolynomialDecayWithOffset = _make_offset_wrapper(
+    "PolynomialDecayWithOffset", tf.keras.optimizers.schedules.PolynomialDecay)
+ExponentialDecayWithOffset = _make_offset_wrapper(
+    "ExponentialDecayWithOffset",
+    tf.keras.optimizers.schedules.ExponentialDecay)
+CosineDecayWithOffset = _make_offset_wrapper("CosineDecayWithOffset",
+                                             tf.keras.experimental.CosineDecay)
+
+
 class LinearWarmup(tf.keras.optimizers.schedules.LearningRateSchedule):
   """Linear warmup schedule."""
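A brief sketch of the intended use (not part of the commit): the offset lets a decay schedule start counting after a warmup period, so the decay is defined on post-warmup steps. The numbers below are illustrative, and exactly how the optimizer factory wires the new config `offset` fields into these classes is not shown on this page.

```
import tensorflow as tf

from official.modeling.optimization import lr_schedule

warmup_steps = 100
decayed = lr_schedule.CosineDecayWithOffset(
    offset=warmup_steps, initial_learning_rate=0.1, decay_steps=1000)
base = tf.keras.experimental.CosineDecay(
    initial_learning_rate=0.1, decay_steps=1000)

# decayed(warmup_steps + k) == base(k)
print(float(decayed(150)), float(base(50)))
```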
official/modeling/optimization/lr_schedule_test.py

@@ -70,5 +70,40 @@ class PowerAndLinearDecayTest(tf.test.TestCase, parameterized.TestCase):
       self.assertAlmostEqual(lr(step).numpy(), value)


+class OffsetLearningRateTest(tf.test.TestCase, parameterized.TestCase):
+
+  @parameterized.parameters(
+      dict(class_name=lr_schedule.PiecewiseConstantDecayWithOffset),
+      dict(class_name=lr_schedule.PolynomialDecayWithOffset),
+      dict(class_name=lr_schedule.ExponentialDecayWithOffset),
+      dict(class_name=lr_schedule.CosineDecayWithOffset),
+  )
+  def test_generated_docstring(self, class_name):
+    self.assertNotEmpty(class_name.__init__.__doc__)
+
+  @parameterized.parameters(
+      dict(
+          class_name=lr_schedule.PiecewiseConstantDecayWithOffset,
+          kwarg=dict(boundaries=[50, 80], values=[1.0, 0.5, 0.1])),
+      dict(
+          class_name=lr_schedule.PolynomialDecayWithOffset,
+          kwarg=dict(initial_learning_rate=1.0, decay_steps=100)),
+      dict(
+          class_name=lr_schedule.ExponentialDecayWithOffset,
+          kwarg=dict(
+              initial_learning_rate=1.0, decay_steps=100, decay_rate=0.5)),
+      dict(
+          class_name=lr_schedule.CosineDecayWithOffset,
+          kwarg=dict(initial_learning_rate=1.0, decay_steps=100)),
+  )
+  def test_offset(self, class_name, kwarg):
+    offset = 10
+    offset_lr = class_name(offset=offset, **kwarg)
+    base_lr = class_name.base_lr_class(**kwarg)
+    self.assertIsInstance(offset_lr, class_name)
+    for step in range(10, 101, 10):
+      self.assertEqual(offset_lr(step), base_lr(step - offset))
+
+
 if __name__ == '__main__':
   tf.test.main()