ModelZoo / ResNet50_tensorflow · Commits

Commit 54602a66, authored Mar 25, 2020 by Allen Wang, committed by A. Unique TensorFlower on Mar 25, 2020.

Internal change

PiperOrigin-RevId: 302937425

Parent commit: 7a257585
Showing 7 changed files with 1509 additions and 0 deletions (+1509, -0):
- official/vision/image_classification/efficientnet/efficientnet_model.py (+503, -0)
- official/vision/image_classification/learning_rate.py (+120, -0)
- official/vision/image_classification/learning_rate_test.py (+90, -0)
- official/vision/image_classification/optimizer_factory.py (+161, -0)
- official/vision/image_classification/optimizer_factory_test.py (+115, -0)
- official/vision/image_classification/preprocessing.py (+391, -0)
- official/vision/image_classification/resnet/README.md (+129, -0)
official/vision/image_classification/efficientnet/efficientnet_model.py (new file, mode 0 → 100644)
# Lint as: python3
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains definitions for EfficientNet model.
[1] Mingxing Tan, Quoc V. Le
EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks.
ICML'19, https://arxiv.org/abs/1905.11946
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import os
from typing import Any, Dict, Optional, Text, Tuple

from absl import logging
from dataclasses import dataclass
import tensorflow.compat.v2 as tf

from official.modeling import tf_utils
from official.modeling.hyperparams import base_config
from official.vision.image_classification import preprocessing
from official.vision.image_classification.efficientnet import common_modules


@dataclass
class BlockConfig(base_config.Config):
  """Config for a single MB Conv Block."""
  input_filters: int = 0
  output_filters: int = 0
  kernel_size: int = 3
  num_repeat: int = 1
  expand_ratio: int = 1
  strides: Tuple[int, int] = (1, 1)
  se_ratio: Optional[float] = None
  id_skip: bool = True
  fused_conv: bool = False
  conv_type: str = 'depthwise'


@dataclass
class ModelConfig(base_config.Config):
  """Default Config for Efficientnet-B0."""
  width_coefficient: float = 1.0
  depth_coefficient: float = 1.0
  resolution: int = 224
  dropout_rate: float = 0.2
  blocks: Tuple[BlockConfig, ...] = (
      # (input_filters, output_filters, kernel_size, num_repeat,
      #  expand_ratio, strides, se_ratio)
      # pylint: disable=bad-whitespace
      BlockConfig.from_args(32, 16, 3, 1, 1, (1, 1), 0.25),
      BlockConfig.from_args(16, 24, 3, 2, 6, (2, 2), 0.25),
      BlockConfig.from_args(24, 40, 5, 2, 6, (2, 2), 0.25),
      BlockConfig.from_args(40, 80, 3, 3, 6, (2, 2), 0.25),
      BlockConfig.from_args(80, 112, 5, 3, 6, (1, 1), 0.25),
      BlockConfig.from_args(112, 192, 5, 4, 6, (2, 2), 0.25),
      BlockConfig.from_args(192, 320, 3, 1, 6, (1, 1), 0.25),
      # pylint: enable=bad-whitespace
  )
  stem_base_filters: int = 32
  top_base_filters: int = 1280
  activation: str = 'simple_swish'
  batch_norm: str = 'default'
  bn_momentum: float = 0.99
  bn_epsilon: float = 1e-3
  # While the original implementation used a weight decay of 1e-5,
  # tf.nn.l2_loss divides it by 2, so we halve this to compensate in Keras
  weight_decay: float = 5e-6
  drop_connect_rate: float = 0.2
  depth_divisor: int = 8
  min_depth: Optional[int] = None
  use_se: bool = True
  input_channels: int = 3
  num_classes: int = 1000
  model_name: str = 'efficientnet'
  rescale_input: bool = True
  data_format: str = 'channels_last'
  dtype: str = 'float32'


MODEL_CONFIGS = {
    # (width, depth, resolution, dropout)
    'efficientnet-b0': ModelConfig.from_args(1.0, 1.0, 224, 0.2),
    'efficientnet-b1': ModelConfig.from_args(1.0, 1.1, 240, 0.2),
    'efficientnet-b2': ModelConfig.from_args(1.1, 1.2, 260, 0.3),
    'efficientnet-b3': ModelConfig.from_args(1.2, 1.4, 300, 0.3),
    'efficientnet-b4': ModelConfig.from_args(1.4, 1.8, 380, 0.4),
    'efficientnet-b5': ModelConfig.from_args(1.6, 2.2, 456, 0.4),
    'efficientnet-b6': ModelConfig.from_args(1.8, 2.6, 528, 0.5),
    'efficientnet-b7': ModelConfig.from_args(2.0, 3.1, 600, 0.5),
}

CONV_KERNEL_INITIALIZER = {
    'class_name': 'VarianceScaling',
    'config': {
        'scale': 2.0,
        'mode': 'fan_out',
        # Note: this is a truncated normal distribution
        'distribution': 'normal'
    }
}

DENSE_KERNEL_INITIALIZER = {
    'class_name': 'VarianceScaling',
    'config': {
        'scale': 1 / 3.0,
        'mode': 'fan_out',
        'distribution': 'uniform'
    }
}


def round_filters(filters: int, config: ModelConfig) -> int:
  """Round number of filters based on width coefficient."""
  width_coefficient = config.width_coefficient
  min_depth = config.min_depth
  divisor = config.depth_divisor
  orig_filters = filters

  if not width_coefficient:
    return filters

  filters *= width_coefficient
  min_depth = min_depth or divisor
  new_filters = max(min_depth, int(filters + divisor / 2) // divisor * divisor)
  # Make sure that round down does not go down by more than 10%.
  if new_filters < 0.9 * filters:
    new_filters += divisor
  logging.info('round_filter input=%s output=%s', orig_filters, new_filters)
  return int(new_filters)


def round_repeats(repeats: int, depth_coefficient: float) -> int:
  """Round number of repeats based on depth coefficient."""
  return int(math.ceil(depth_coefficient * repeats))


def conv2d_block(inputs: tf.Tensor,
                 conv_filters: Optional[int],
                 config: ModelConfig,
                 kernel_size: Any = (1, 1),
                 strides: Any = (1, 1),
                 use_batch_norm: bool = True,
                 use_bias: bool = False,
                 activation: Any = None,
                 depthwise: bool = False,
                 name: Text = None):
  """A conv2d followed by batch norm and an activation."""
  batch_norm = common_modules.get_batch_norm(config.batch_norm)
  bn_momentum = config.bn_momentum
  bn_epsilon = config.bn_epsilon
  data_format = config.data_format
  weight_decay = config.weight_decay

  name = name or ''

  # Collect args based on what kind of conv2d block is desired
  init_kwargs = {
      'kernel_size': kernel_size,
      'strides': strides,
      'use_bias': use_bias,
      'padding': 'same',
      'name': name + '_conv2d',
      'kernel_regularizer': tf.keras.regularizers.l2(weight_decay),
      'bias_regularizer': tf.keras.regularizers.l2(weight_decay),
  }

  if depthwise:
    conv2d = tf.keras.layers.DepthwiseConv2D
    init_kwargs.update({'depthwise_initializer': CONV_KERNEL_INITIALIZER})
  else:
    conv2d = tf.keras.layers.Conv2D
    init_kwargs.update({
        'filters': conv_filters,
        'kernel_initializer': CONV_KERNEL_INITIALIZER
    })

  x = conv2d(**init_kwargs)(inputs)

  if use_batch_norm:
    bn_axis = 1 if data_format == 'channels_first' else -1
    x = batch_norm(
        axis=bn_axis,
        momentum=bn_momentum,
        epsilon=bn_epsilon,
        name=name + '_bn')(x)

  if activation is not None:
    x = tf.keras.layers.Activation(activation, name=name + '_activation')(x)
  return x


def mb_conv_block(inputs: tf.Tensor,
                  block: BlockConfig,
                  config: ModelConfig,
                  prefix: Text = None):
  """Mobile Inverted Residual Bottleneck.

  Args:
    inputs: the Keras input to the block
    block: BlockConfig, arguments to create a Block
    config: ModelConfig, a set of model parameters
    prefix: prefix for naming all layers

  Returns:
    the output of the block
  """
  use_se = config.use_se
  activation = tf_utils.get_activation(config.activation)
  drop_connect_rate = config.drop_connect_rate
  data_format = config.data_format
  use_depthwise = block.conv_type != 'no_depthwise'
  prefix = prefix or ''

  filters = block.input_filters * block.expand_ratio

  x = inputs

  if block.fused_conv:
    # If we use fused mbconv, skip expansion and use regular conv.
    x = conv2d_block(
        x,
        filters,
        config,
        kernel_size=block.kernel_size,
        strides=block.strides,
        activation=activation,
        name=prefix + 'fused')
  else:
    if block.expand_ratio != 1:
      # Expansion phase
      kernel_size = (1, 1) if use_depthwise else (3, 3)
      x = conv2d_block(
          x,
          filters,
          config,
          kernel_size=kernel_size,
          activation=activation,
          name=prefix + 'expand')

    # Depthwise Convolution
    if use_depthwise:
      x = conv2d_block(
          x,
          conv_filters=None,
          config=config,
          kernel_size=block.kernel_size,
          strides=block.strides,
          activation=activation,
          depthwise=True,
          name=prefix + 'depthwise')

  # Squeeze and Excitation phase
  if use_se:
    assert block.se_ratio is not None
    assert 0 < block.se_ratio <= 1
    num_reduced_filters = max(1, int(block.input_filters * block.se_ratio))

    if data_format == 'channels_first':
      se_shape = (filters, 1, 1)
    else:
      se_shape = (1, 1, filters)

    se = tf.keras.layers.GlobalAveragePooling2D(name=prefix + 'se_squeeze')(x)
    se = tf.keras.layers.Reshape(se_shape, name=prefix + 'se_reshape')(se)

    se = conv2d_block(
        se,
        num_reduced_filters,
        config,
        use_bias=True,
        use_batch_norm=False,
        activation=activation,
        name=prefix + 'se_reduce')
    se = conv2d_block(
        se,
        filters,
        config,
        use_bias=True,
        use_batch_norm=False,
        activation='sigmoid',
        name=prefix + 'se_expand')
    x = tf.keras.layers.multiply([x, se], name=prefix + 'se_excite')

  # Output phase
  x = conv2d_block(
      x,
      block.output_filters,
      config,
      activation=None,
      name=prefix + 'project')

  # Add identity so that quantization-aware training can insert quantization
  # ops correctly.
  x = tf.keras.layers.Activation(
      tf_utils.get_activation('identity'), name=prefix + 'id')(x)

  if (block.id_skip
      and all(s == 1 for s in block.strides)
      and block.input_filters == block.output_filters):
    if drop_connect_rate and drop_connect_rate > 0:
      # Apply dropconnect
      # The only difference between dropout and dropconnect in TF is scaling by
      # drop_connect_rate during training. See:
      # https://github.com/keras-team/keras/pull/9898#issuecomment-380577612
      x = tf.keras.layers.Dropout(
          drop_connect_rate,
          noise_shape=(None, 1, 1, 1),
          name=prefix + 'drop')(x)

    x = tf.keras.layers.add([x, inputs], name=prefix + 'add')

  return x


def efficientnet(image_input: tf.keras.layers.Input, config: ModelConfig):
  """Creates an EfficientNet graph given the model parameters.

  This function is wrapped by the `EfficientNet` class to make a tf.keras.Model.

  Args:
    image_input: the input batch of images
    config: the model config

  Returns:
    the output of efficientnet
  """
  depth_coefficient = config.depth_coefficient
  blocks = config.blocks
  stem_base_filters = config.stem_base_filters
  top_base_filters = config.top_base_filters
  activation = tf_utils.get_activation(config.activation)
  dropout_rate = config.dropout_rate
  drop_connect_rate = config.drop_connect_rate
  num_classes = config.num_classes
  input_channels = config.input_channels
  rescale_input = config.rescale_input
  data_format = config.data_format
  dtype = config.dtype
  weight_decay = config.weight_decay

  x = image_input
  if rescale_input:
    x = preprocessing.normalize_images(
        x, num_channels=input_channels, dtype=dtype, data_format=data_format)

  # Build stem
  x = conv2d_block(
      x,
      round_filters(stem_base_filters, config),
      config,
      kernel_size=[3, 3],
      strides=[2, 2],
      activation=activation,
      name='stem')

  # Build blocks
  num_blocks_total = sum(block.num_repeat for block in blocks)
  block_num = 0

  for stack_idx, block in enumerate(blocks):
    assert block.num_repeat > 0
    # Update block input and output filters based on depth multiplier
    block = block.replace(
        input_filters=round_filters(block.input_filters, config),
        output_filters=round_filters(block.output_filters, config),
        num_repeat=round_repeats(block.num_repeat, depth_coefficient))

    # The first block needs to take care of stride and filter size increase
    drop_rate = drop_connect_rate * float(block_num) / num_blocks_total
    config = config.replace(drop_connect_rate=drop_rate)
    block_prefix = 'stack_{}/block_0/'.format(stack_idx)
    x = mb_conv_block(x, block, config, block_prefix)
    block_num += 1
    if block.num_repeat > 1:
      block = block.replace(
          input_filters=block.output_filters,
          strides=[1, 1]
      )

      for block_idx in range(block.num_repeat - 1):
        drop_rate = drop_connect_rate * float(block_num) / num_blocks_total
        config = config.replace(drop_connect_rate=drop_rate)
        block_prefix = 'stack_{}/block_{}/'.format(stack_idx, block_idx + 1)
        x = mb_conv_block(x, block, config, prefix=block_prefix)
        block_num += 1

  # Build top
  x = conv2d_block(
      x,
      round_filters(top_base_filters, config),
      config,
      activation=activation,
      name='top')

  # Build classifier
  x = tf.keras.layers.GlobalAveragePooling2D(name='top_pool')(x)
  if dropout_rate and dropout_rate > 0:
    x = tf.keras.layers.Dropout(dropout_rate, name='top_dropout')(x)
  x = tf.keras.layers.Dense(
      num_classes,
      kernel_initializer=DENSE_KERNEL_INITIALIZER,
      kernel_regularizer=tf.keras.regularizers.l2(weight_decay),
      bias_regularizer=tf.keras.regularizers.l2(weight_decay),
      name='logits')(x)
  x = tf.keras.layers.Activation('softmax', name='probs')(x)

  return x


@tf.keras.utils.register_keras_serializable(package='Vision')
class EfficientNet(tf.keras.Model):
  """Wrapper class for an EfficientNet Keras model.

  Contains helper methods to build, manage, and save metadata about the model.
  """

  def __init__(self,
               config: ModelConfig = None,
               overrides: Dict[Text, Any] = None):
    """Create an EfficientNet model.

    Args:
      config: (optional) the main model parameters to create the model
      overrides: (optional) a dict containing keys that can override config
    """
    overrides = overrides or {}
    config = config or ModelConfig()

    self.config = config.replace(**overrides)

    input_channels = self.config.input_channels
    model_name = self.config.model_name
    input_shape = (None, None, input_channels)  # Should handle any size image
    image_input = tf.keras.layers.Input(shape=input_shape)

    output = efficientnet(image_input, self.config)

    # Cast to float32 in case we have a different model dtype
    output = tf.cast(output, tf.float32)

    logging.info('Building model %s with params %s', model_name, self.config)

    super(EfficientNet, self).__init__(
        inputs=image_input, outputs=output, name=model_name)

  @classmethod
  def from_name(cls,
                model_name: Text,
                model_weights_path: Text = None,
                copy_to_local: bool = False,
                overrides: Dict[Text, Any] = None):
    """Construct an EfficientNet model from a predefined model name.

    E.g., `EfficientNet.from_name('efficientnet-b0')`.

    Args:
      model_name: the predefined model name
      model_weights_path: the path to the weights (h5 file or saved model dir)
      copy_to_local: copy the weights to a local tmp dir
      overrides: (optional) a dict containing keys that can override config

    Returns:
      A constructed EfficientNet instance.
    """
    model_configs = dict(MODEL_CONFIGS)
    overrides = dict(overrides) if overrides else {}

    # One can define their own custom models if necessary
    model_configs.update(overrides.pop('model_config', {}))

    if model_name not in model_configs:
      raise ValueError('Unknown model name {}'.format(model_name))

    config = model_configs[model_name]

    model = cls(config=config, overrides=overrides)

    if model_weights_path:
      if copy_to_local:
        tmp_file = os.path.join('/tmp', model_name + '.h5')
        model_weights_file = os.path.join(model_weights_path, 'model.h5')
        tf.io.gfile.copy(model_weights_file, tmp_file, overwrite=True)
        model_weights_path = tmp_file
      model.load_weights(model_weights_path)

    return model
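For orientation, here is a minimal sketch of how the new `EfficientNet` wrapper can be instantiated from the `MODEL_CONFIGS` table defined above. It is not part of this commit; the input size is an illustrative choice, and `rescale_input` is disabled so the sketch does not depend on the (collapsed) preprocessing module.

```python
import numpy as np
import tensorflow.compat.v2 as tf

from official.vision.image_classification.efficientnet import efficientnet_model

# Build an EfficientNet-B0 graph from the predefined config table.
# `overrides` accepts any ModelConfig field; here we assume the inputs are
# already normalized, so input rescaling is turned off.
model = efficientnet_model.EfficientNet.from_name(
    'efficientnet-b0', overrides={'rescale_input': False})

# The graph accepts any spatial size; 224x224 matches the B0 training resolution.
images = tf.constant(
    np.random.uniform(size=(1, 224, 224, 3)).astype(np.float32))
probs = model(images, training=False)  # softmax over num_classes (1000 by default)
print(probs.shape)  # (1, 1000)
```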
official/vision/image_classification/learning_rate.py (new file, mode 0 → 100644)
# Lint as: python3
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Learning rate utilities for vision tasks."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from typing import Any, List, Mapping

import tensorflow.compat.v2 as tf

BASE_LEARNING_RATE = 0.1


class WarmupDecaySchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
  """A wrapper for LearningRateSchedule that includes warmup steps."""

  def __init__(self,
               lr_schedule: tf.keras.optimizers.schedules.LearningRateSchedule,
               warmup_steps: int):
    """Add warmup decay to a learning rate schedule.

    Args:
      lr_schedule: base learning rate scheduler
      warmup_steps: number of warmup steps
    """
    super(WarmupDecaySchedule, self).__init__()
    self._lr_schedule = lr_schedule
    self._warmup_steps = warmup_steps

  def __call__(self, step: int):
    lr = self._lr_schedule(step)
    if self._warmup_steps:
      initial_learning_rate = tf.convert_to_tensor(
          self._lr_schedule.initial_learning_rate,
          name="initial_learning_rate")
      dtype = initial_learning_rate.dtype
      global_step_recomp = tf.cast(step, dtype)
      warmup_steps = tf.cast(self._warmup_steps, dtype)
      warmup_lr = initial_learning_rate * global_step_recomp / warmup_steps
      lr = tf.cond(global_step_recomp < warmup_steps,
                   lambda: warmup_lr,
                   lambda: lr)
    return lr

  def get_config(self) -> Mapping[str, Any]:
    config = self._lr_schedule.get_config()
    config.update({
        "warmup_steps": self._warmup_steps,
    })
    return config


# TODO(b/149030439) - refactor this with
# tf.keras.optimizers.schedules.PiecewiseConstantDecay + WarmupDecaySchedule.
class PiecewiseConstantDecayWithWarmup(
    tf.keras.optimizers.schedules.LearningRateSchedule):
  """Piecewise constant decay with warmup schedule."""

  def __init__(self,
               batch_size: int,
               epoch_size: int,
               warmup_epochs: int,
               boundaries: List[int],
               multipliers: List[float]):
    """Piecewise constant decay with warmup.

    Args:
      batch_size: The training batch size used in the experiment.
      epoch_size: The size of an epoch, or the number of examples in an epoch.
      warmup_epochs: The number of warmup epochs to apply.
      boundaries: The list of floats with strictly increasing entries.
      multipliers: The list of multipliers/learning rates to use for the
        piecewise portion. The length must be 1 less than that of boundaries.
    """
    super(PiecewiseConstantDecayWithWarmup, self).__init__()
    if len(boundaries) != len(multipliers) - 1:
      raise ValueError("The length of boundaries must be 1 less than the "
                       "length of multipliers")

    base_lr_batch_size = 256
    steps_per_epoch = epoch_size // batch_size

    self._rescaled_lr = BASE_LEARNING_RATE * batch_size / base_lr_batch_size
    self._step_boundaries = [float(steps_per_epoch) * x for x in boundaries]
    self._lr_values = [self._rescaled_lr * m for m in multipliers]
    self._warmup_steps = warmup_epochs * steps_per_epoch

  def __call__(self, step: int):
    """Compute learning rate at given step."""

    def warmup_lr():
      return self._rescaled_lr * (
          step / tf.cast(self._warmup_steps, tf.float32))

    def piecewise_lr():
      return tf.compat.v1.train.piecewise_constant(
          tf.cast(step, tf.float32), self._step_boundaries, self._lr_values)

    return tf.cond(step < self._warmup_steps, warmup_lr, piecewise_lr)

  def get_config(self) -> Mapping[str, Any]:
    return {
        "rescaled_lr": self._rescaled_lr,
        "step_boundaries": self._step_boundaries,
        "lr_values": self._lr_values,
        "warmup_steps": self._warmup_steps,
    }
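A minimal sketch of how `WarmupDecaySchedule` is meant to be combined with a standard Keras schedule and optimizer; the concrete step counts and rates below are illustrative, not prescribed by this file.

```python
import tensorflow.compat.v2 as tf

from official.vision.image_classification import learning_rate

# Wrap an exponential decay schedule with 500 linear warmup steps.
base_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=0.1, decay_steps=2000, decay_rate=0.97)
schedule = learning_rate.WarmupDecaySchedule(
    lr_schedule=base_schedule, warmup_steps=500)

# Before step 500 the rate ramps linearly from 0 toward 0.1;
# afterwards it follows the wrapped exponential decay.
optimizer = tf.keras.optimizers.SGD(learning_rate=schedule, momentum=0.9)
```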
official/vision/image_classification/learning_rate_test.py (new file, mode 0 → 100644)
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for learning_rate."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow.compat.v2 as tf

from official.vision.image_classification import learning_rate


class LearningRateTests(tf.test.TestCase):

  def test_warmup_decay(self):
    """Basic computational test for warmup decay."""
    initial_lr = 0.01
    decay_steps = 100
    decay_rate = 0.01
    warmup_steps = 10

    base_lr = tf.keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate=initial_lr,
        decay_steps=decay_steps,
        decay_rate=decay_rate)
    lr = learning_rate.WarmupDecaySchedule(
        lr_schedule=base_lr,
        warmup_steps=warmup_steps)

    for step in range(warmup_steps - 1):
      config = lr.get_config()
      self.assertEqual(config['warmup_steps'], warmup_steps)
      self.assertAllClose(self.evaluate(lr(step)),
                          step / warmup_steps * initial_lr)

  def test_piecewise_constant_decay_with_warmup(self):
    """Basic computational test for piecewise constant decay with warmup."""
    boundaries = [1, 2, 3]
    warmup_epochs = boundaries[0]
    learning_rate_multipliers = [1.0, 0.1, 0.001]
    expected_keys = [
        'rescaled_lr', 'step_boundaries', 'lr_values', 'warmup_steps',
    ]

    expected_lrs = [0.0, 0.1, 0.1]

    lr = learning_rate.PiecewiseConstantDecayWithWarmup(
        batch_size=256,
        epoch_size=256,
        warmup_epochs=warmup_epochs,
        boundaries=boundaries[1:],
        multipliers=learning_rate_multipliers)

    step = 0

    config = lr.get_config()
    self.assertAllInSet(list(config.keys()), expected_keys)

    for boundary, expected_lr in zip(boundaries, expected_lrs):
      for _ in range(step, boundary):
        self.assertAllClose(self.evaluate(lr(step)), expected_lr)
        step += 1

  def test_piecewise_constant_decay_invalid_boundaries(self):
    with self.assertRaisesRegex(ValueError,
                                'The length of boundaries must be 1 less '):
      learning_rate.PiecewiseConstantDecayWithWarmup(
          batch_size=256,
          epoch_size=256,
          warmup_epochs=1,
          boundaries=[1, 2],
          multipliers=[1, 2])


if __name__ == '__main__':
  assert tf.version.VERSION.startswith('2.')
  tf.test.main()
official/vision/image_classification/optimizer_factory.py (new file, mode 0 → 100644)
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Optimizer factory for vision tasks."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function

from absl import logging
import tensorflow.compat.v2 as tf
import tensorflow_addons as tfa

from typing import Any, Dict, Text

from official.vision.image_classification import learning_rate
from official.vision.image_classification.configs import base_configs


def build_optimizer(
    optimizer_name: Text,
    base_learning_rate: tf.keras.optimizers.schedules.LearningRateSchedule,
    params: Dict[Text, Any]):
  """Build the optimizer based on name.

  Args:
    optimizer_name: String representation of the optimizer name. Examples:
      sgd, momentum, rmsprop.
    base_learning_rate: `tf.keras.optimizers.schedules.LearningRateSchedule`
      base learning rate.
    params: String -> Any dictionary representing the optimizer params.
      This should contain optimizer specific parameters such as
      `base_learning_rate`, `decay`, etc.

  Returns:
    A tf.keras.Optimizer.

  Raises:
    ValueError if the provided optimizer_name is not supported.
  """
  optimizer_name = optimizer_name.lower()
  logging.info('Building %s optimizer with params %s', optimizer_name, params)

  if optimizer_name == 'sgd':
    logging.info('Using SGD optimizer')
    nesterov = params.get('nesterov', False)
    optimizer = tf.keras.optimizers.SGD(
        learning_rate=base_learning_rate, nesterov=nesterov)
  elif optimizer_name == 'momentum':
    logging.info('Using momentum optimizer')
    nesterov = params.get('nesterov', False)
    optimizer = tf.keras.optimizers.SGD(
        learning_rate=base_learning_rate,
        momentum=params['momentum'],
        nesterov=nesterov)
  elif optimizer_name == 'rmsprop':
    logging.info('Using RMSProp')
    rho = params.get('decay', None) or params.get('rho', 0.9)
    momentum = params.get('momentum', 0.9)
    epsilon = params.get('epsilon', 1e-07)
    optimizer = tf.keras.optimizers.RMSprop(
        learning_rate=base_learning_rate,
        rho=rho,
        momentum=momentum,
        epsilon=epsilon)
  elif optimizer_name == 'adam':
    logging.info('Using Adam')
    beta_1 = params.get('beta_1', 0.9)
    beta_2 = params.get('beta_2', 0.999)
    epsilon = params.get('epsilon', 1e-07)
    optimizer = tf.keras.optimizers.Adam(
        learning_rate=base_learning_rate,
        beta_1=beta_1,
        beta_2=beta_2,
        epsilon=epsilon)
  elif optimizer_name == 'adamw':
    logging.info('Using AdamW')
    weight_decay = params.get('weight_decay', 0.01)
    beta_1 = params.get('beta_1', 0.9)
    beta_2 = params.get('beta_2', 0.999)
    epsilon = params.get('epsilon', 1e-07)
    optimizer = tfa.optimizers.AdamW(
        weight_decay=weight_decay,
        learning_rate=base_learning_rate,
        beta_1=beta_1,
        beta_2=beta_2,
        epsilon=epsilon)
  else:
    raise ValueError('Unknown optimizer %s' % optimizer_name)

  moving_average_decay = params.get('moving_average_decay', 0.)
  if moving_average_decay is not None and moving_average_decay > 0.:
    logging.info('Including moving average decay.')
    optimizer = tfa.optimizers.MovingAverage(
        optimizer,
        average_decay=params['moving_average_decay'],
        num_updates=None)
  if params.get('lookahead', None):
    logging.info('Using lookahead optimizer.')
    optimizer = tfa.optimizers.Lookahead(optimizer)

  return optimizer


def build_learning_rate(params: base_configs.LearningRateConfig,
                        batch_size: int = None,
                        train_steps: int = None):
  """Build the learning rate given the provided configuration."""
  decay_type = params.name
  base_lr = params.initial_lr
  decay_rate = params.decay_rate
  if params.decay_epochs is not None:
    decay_steps = params.decay_epochs * train_steps
  else:
    decay_steps = 0
  if params.warmup_epochs is not None:
    warmup_steps = params.warmup_epochs * train_steps
  else:
    warmup_steps = 0

  lr_multiplier = params.scale_by_batch_size

  if lr_multiplier and lr_multiplier > 0:
    # Scale the learning rate based on the batch size and a multiplier
    base_lr *= lr_multiplier * batch_size
    logging.info('Scaling the learning rate based on the batch size '
                 'multiplier. New base_lr: %f', base_lr)

  if decay_type == 'exponential':
    logging.info('Using exponential learning rate with: '
                 'initial_learning_rate: %f, decay_steps: %d, '
                 'decay_rate: %f', base_lr, decay_steps, decay_rate)
    lr = tf.keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate=base_lr,
        decay_steps=decay_steps,
        decay_rate=decay_rate)
  elif decay_type == 'piecewise_constant_with_warmup':
    logging.info('Using Piecewise constant decay with warmup. '
                 'Parameters: batch_size: %d, epoch_size: %d, '
                 'warmup_epochs: %d, boundaries: %s, multipliers: %s',
                 batch_size, params.examples_per_epoch,
                 params.warmup_epochs, params.boundaries, params.multipliers)
    lr = learning_rate.PiecewiseConstantDecayWithWarmup(
        batch_size=batch_size,
        epoch_size=params.examples_per_epoch,
        warmup_epochs=params.warmup_epochs,
        boundaries=params.boundaries,
        multipliers=params.multipliers)
  if warmup_steps > 0:
    if decay_type != 'piecewise_constant_with_warmup':
      logging.info('Applying %d warmup steps to the learning rate',
                   warmup_steps)
      lr = learning_rate.WarmupDecaySchedule(lr, warmup_steps)
  return lr
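To show how the two factories compose, here is a hedged sketch that builds a warmed-up exponential schedule and feeds it to a momentum optimizer. The numbers are illustrative, and the `LearningRateConfig` fields are filled in the same way as in the tests that follow.

```python
from official.vision.image_classification import optimizer_factory
from official.vision.image_classification.configs import base_configs

batch_size = 256
steps_per_epoch = 5000  # illustrative; normally derived from the dataset size

lr_params = base_configs.LearningRateConfig(
    name='exponential',
    initial_lr=0.1,
    decay_rate=0.97,
    decay_epochs=2,
    warmup_epochs=5,
    scale_by_batch_size=0.,      # keep base_lr unscaled in this sketch
    examples_per_epoch=1281167,
    boundaries=[30, 60],         # unused by the exponential schedule
    multipliers=[1.0, 0.1, 0.01])

# A warmed-up ExponentialDecay schedule, then an SGD-with-momentum optimizer.
lr_schedule = optimizer_factory.build_learning_rate(
    params=lr_params, batch_size=batch_size, train_steps=steps_per_epoch)
optimizer = optimizer_factory.build_optimizer(
    optimizer_name='momentum',
    base_learning_rate=lr_schedule,
    params={'momentum': 0.9, 'nesterov': True, 'moving_average_decay': 0.})
```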
official/vision/image_classification/optimizer_factory_test.py (new file, mode 0 → 100644)
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for optimizer_factory."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function

import tensorflow.compat.v2 as tf

from absl.testing import parameterized

from official.vision.image_classification import optimizer_factory
from official.vision.image_classification.configs import base_configs


class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):

  @parameterized.named_parameters(
      ('sgd', 'sgd', 0., False),
      ('momentum', 'momentum', 0., False),
      ('rmsprop', 'rmsprop', 0., False),
      ('adam', 'adam', 0., False),
      ('adamw', 'adamw', 0., False),
      ('momentum_lookahead', 'momentum', 0., True),
      ('sgd_ema', 'sgd', 0.001, False),
      ('momentum_ema', 'momentum', 0.001, False),
      ('rmsprop_ema', 'rmsprop', 0.001, False))
  def test_optimizer(self, optimizer_name, moving_average_decay, lookahead):
    """Smoke test to be sure no syntax errors."""
    params = {
        'learning_rate': 0.001,
        'rho': 0.09,
        'momentum': 0.,
        'epsilon': 1e-07,
        'moving_average_decay': moving_average_decay,
        'lookahead': lookahead,
    }
    optimizer = optimizer_factory.build_optimizer(
        optimizer_name=optimizer_name,
        base_learning_rate=params['learning_rate'],
        params=params)
    self.assertTrue(issubclass(type(optimizer), tf.keras.optimizers.Optimizer))

  def test_unknown_optimizer(self):
    with self.assertRaises(ValueError):
      optimizer_factory.build_optimizer(
          optimizer_name='this_optimizer_does_not_exist',
          base_learning_rate=None,
          params=None)

  def test_learning_rate_without_decay_or_warmups(self):
    params = base_configs.LearningRateConfig(
        name='exponential',
        initial_lr=0.01,
        decay_rate=0.01,
        decay_epochs=None,
        warmup_epochs=None,
        scale_by_batch_size=0.01,
        examples_per_epoch=1,
        boundaries=[0],
        multipliers=[0, 1])
    batch_size = 1
    train_steps = 1

    lr = optimizer_factory.build_learning_rate(
        params=params,
        batch_size=batch_size,
        train_steps=train_steps)
    self.assertTrue(
        issubclass(type(lr),
                   tf.keras.optimizers.schedules.LearningRateSchedule))

  @parameterized.named_parameters(
      ('exponential', 'exponential'),
      ('piecewise_constant_with_warmup', 'piecewise_constant_with_warmup'))
  def test_learning_rate_with_decay_and_warmup(self, lr_decay_type):
    """Basic smoke test for syntax."""
    params = base_configs.LearningRateConfig(
        name=lr_decay_type,
        initial_lr=0.01,
        decay_rate=0.01,
        decay_epochs=1,
        warmup_epochs=1,
        scale_by_batch_size=0.01,
        examples_per_epoch=1,
        boundaries=[0],
        multipliers=[0, 1])
    batch_size = 1
    train_steps = 1

    lr = optimizer_factory.build_learning_rate(
        params=params,
        batch_size=batch_size,
        train_steps=train_steps)
    self.assertTrue(
        issubclass(type(lr),
                   tf.keras.optimizers.schedules.LearningRateSchedule))


if __name__ == '__main__':
  assert tf.version.VERSION.startswith('2.')
  tf.test.main()
official/vision/image_classification/preprocessing.py (new file, mode 0 → 100644; diff collapsed, not shown on this page)
official/vision/image_classification/resnet/README.md (new file, mode 0 → 100644)
This folder contains a compile/fit and
[custom training loop (CTL)](#resnet-custom-training-loop) implementation for
ResNet50.

## Before you begin

Please refer to the [README](../README.md) in the parent directory for
information on setup and preparing the data.

## ResNet (custom training loop)

Similar to the [estimator implementation](../../../r1/resnet), the Keras
implementation has code for the ImageNet dataset. The ImageNet version uses a
ResNet50 model implemented in [`resnet_model.py`](./resnet_model.py).

### Pretrained Models

* [ResNet50 Checkpoints](https://storage.googleapis.com/cloud-tpu-checkpoints/resnet/resnet50.tar.gz)
* ResNet50 TFHub: [feature vector](https://tfhub.dev/tensorflow/resnet_50/feature_vector/1)
  and [classification](https://tfhub.dev/tensorflow/resnet_50/classification/1)

```bash
python3 resnet_imagenet_main.py
```

Again, if you did not download the data to the default directory, specify the
location with the `--data_dir` flag:

```bash
python3 resnet_imagenet_main.py --data_dir=/path/to/imagenet
```

There are more flag options you can specify. Here are some examples:

- `--use_synthetic_data`: when set to true, synthetic data, rather than real
  data, are used;
- `--batch_size`: the batch size used for the model;
- `--model_dir`: the directory to save the model checkpoint;
- `--train_epochs`: number of epochs to run for training the model;
- `--train_steps`: number of steps to run for training the model. We currently
  only support a number that is smaller than the number of batches in an epoch;
- `--skip_eval`: when set to true, evaluation as well as validation during
  training is skipped.

For example, this is a typical command line to run with ImageNet data with
batch size 128 per GPU:

```bash
python3 resnet_imagenet_main.py \
    --model_dir=/tmp/model_dir/something \
    --num_gpus=2 \
    --batch_size=128 \
    --train_epochs=90 \
    --train_steps=10 \
    --use_synthetic_data=false
```

See [`common.py`](common.py) for the full list of options.

### Using multiple GPUs

You can train these models on multiple GPUs using the `tf.distribute.Strategy`
API. You can read more about it in this
[guide](https://www.tensorflow.org/guide/distribute_strategy).

In this example, we have made it easier to use with just the command line flag
`--num_gpus`. By default this flag is 1 if TensorFlow is compiled with CUDA,
and 0 otherwise.

- `--num_gpus=0`: Uses `tf.distribute.OneDeviceStrategy` with CPU as the device.
- `--num_gpus=1`: Uses `tf.distribute.OneDeviceStrategy` with GPU as the device.
- `--num_gpus=2+`: Uses `tf.distribute.MirroredStrategy` to run synchronous
  distributed training across the GPUs.

If you wish to run without `tf.distribute.Strategy`, you can do so by setting
`--distribution_strategy=off`.

### Running on multiple GPU hosts

You can also train these models on multiple hosts, each with GPUs, using
`tf.distribute.Strategy`.

The easiest way to run multi-host benchmarks is to set the
[`TF_CONFIG`](https://www.tensorflow.org/guide/distributed_training#TF_CONFIG)
appropriately at each host. For example, to run using
`MultiWorkerMirroredStrategy` on 2 hosts, the `cluster` in `TF_CONFIG` should
have 2 `host:port` entries, and host `i` should have the `task` in `TF_CONFIG`
set to `{"type": "worker", "index": i}`. `MultiWorkerMirroredStrategy` will
automatically use all the available GPUs at each host. A minimal sketch of such
a `TF_CONFIG` is shown below.
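The following sketch sets `TF_CONFIG` for the 2-host case described above; the host addresses are placeholders you would replace with your own.

```python
import json
import os

# On the first host (task index 0); use "index": 1 on the second host.
os.environ["TF_CONFIG"] = json.dumps({
    "cluster": {
        "worker": ["host1.example.com:2222", "host2.example.com:2222"]
    },
    "task": {"type": "worker", "index": 0}
})
```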
### Running on Cloud TPUs

Note: This model will **not** work with TPUs on Colab.

You can train the ResNet CTL model on Cloud TPUs using
`tf.distribute.TPUStrategy`. If you are not familiar with Cloud TPUs, it is
strongly recommended that you go through the
[quickstart](https://cloud.google.com/tpu/docs/quickstart) to learn how to
create a TPU and GCE VM.

To run the ResNet model on a TPU, you must set `--distribution_strategy=tpu`
and `--tpu=$TPU_NAME`, where `$TPU_NAME` is the name of your TPU in the Cloud
Console. From a GCE VM, you can run the following command to train ResNet for
one epoch on a v2-8 or v3-8 TPU by setting `TRAIN_EPOCHS` to 1:

```bash
python3 resnet_ctl_imagenet_main.py \
    --tpu=$TPU_NAME \
    --model_dir=$MODEL_DIR \
    --data_dir=$DATA_DIR \
    --batch_size=1024 \
    --steps_per_loop=500 \
    --train_epochs=$TRAIN_EPOCHS \
    --use_synthetic_data=false \
    --dtype=fp32 \
    --enable_eager=true \
    --enable_tensorboard=true \
    --distribution_strategy=tpu \
    --log_steps=50 \
    --single_l2_loss_op=true \
    --use_tf_function=true
```

To train the ResNet to convergence, run it for 90 epochs by setting
`TRAIN_EPOCHS` to 90.

Note: `$MODEL_DIR` and `$DATA_DIR` must be GCS paths.