Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
d2a038fa
Commit
d2a038fa
authored
Aug 21, 2019
by
Zongwei Zhou
Committed by
A. Unique TensorFlower
Aug 21, 2019
Browse files
Internal change
PiperOrigin-RevId: 264647492
parent
dd03f167
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
270 additions
and
106 deletions
+270
-106
official/resnet/ctl/ctl_common.py
official/resnet/ctl/ctl_common.py
+3
-0
official/resnet/ctl/ctl_imagenet_main.py
official/resnet/ctl/ctl_imagenet_main.py
+24
-8
official/vision/image_classification/resnet_model.py
official/vision/image_classification/resnet_model.py
+243
-98
No files found.
official/resnet/ctl/ctl_common.py
View file @
d2a038fa
...
...
@@ -27,3 +27,6 @@ def define_ctl_flags():
flags
.
DEFINE_boolean
(
name
=
'use_tf_function'
,
default
=
True
,
help
=
'Wrap the train and test step inside a '
'tf.function.'
)
flags
.
DEFINE_boolean
(
name
=
'single_l2_loss_op'
,
default
=
False
,
help
=
'Calculate L2_loss on concatenated weights, '
'instead of using Keras per-layer L2 loss.'
)
official/resnet/ctl/ctl_imagenet_main.py
View file @
d2a038fa
...
...
@@ -137,6 +137,10 @@ def run(flags_obj):
Returns:
Dictionary of training and eval stats.
"""
keras_utils
.
set_session_config
(
enable_eager
=
flags_obj
.
enable_eager
,
enable_xla
=
flags_obj
.
enable_xla
)
dtype
=
flags_core
.
get_tf_dtype
(
flags_obj
)
# TODO(anj-s): Set data_format without using Keras.
...
...
@@ -163,7 +167,8 @@ def run(flags_obj):
with
strategy_scope
:
model
=
resnet_model
.
resnet50
(
num_classes
=
imagenet_preprocessing
.
NUM_CLASSES
,
dtype
=
dtype
,
batch_size
=
flags_obj
.
batch_size
)
dtype
=
dtype
,
batch_size
=
flags_obj
.
batch_size
,
use_l2_regularizer
=
not
flags_obj
.
single_l2_loss_op
)
optimizer
=
tf
.
keras
.
optimizers
.
SGD
(
learning_rate
=
keras_common
.
BASE_LEARNING_RATE
,
momentum
=
0.9
,
...
...
@@ -175,6 +180,8 @@ def run(flags_obj):
test_accuracy
=
tf
.
keras
.
metrics
.
SparseCategoricalAccuracy
(
'test_accuracy'
,
dtype
=
tf
.
float32
)
trainable_variables
=
model
.
trainable_variables
def
train_step
(
train_ds_inputs
):
"""Training StepFn."""
def
step_fn
(
inputs
):
...
...
@@ -185,13 +192,22 @@ def run(flags_obj):
prediction_loss
=
tf
.
keras
.
losses
.
sparse_categorical_crossentropy
(
labels
,
logits
)
loss1
=
tf
.
reduce_sum
(
prediction_loss
)
*
(
1.0
/
flags_obj
.
batch_size
)
loss2
=
(
tf
.
reduce_sum
(
model
.
losses
)
/
tf
.
distribute
.
get_strategy
().
num_replicas_in_sync
)
loss
=
loss1
+
loss2
grads
=
tape
.
gradient
(
loss
,
model
.
trainable_variables
)
optimizer
.
apply_gradients
(
zip
(
grads
,
model
.
trainable_variables
))
loss
=
tf
.
reduce_sum
(
prediction_loss
)
*
(
1.0
/
flags_obj
.
batch_size
)
num_replicas
=
tf
.
distribute
.
get_strategy
().
num_replicas_in_sync
if
flags_obj
.
single_l2_loss_op
:
filtered_variables
=
[
tf
.
reshape
(
v
,
(
-
1
,))
for
v
in
trainable_variables
if
'bn'
not
in
v
.
name
]
l2_loss
=
resnet_model
.
L2_WEIGHT_DECAY
*
2
*
tf
.
nn
.
l2_loss
(
tf
.
concat
(
filtered_variables
,
axis
=
0
))
loss
+=
(
l2_loss
/
num_replicas
)
else
:
loss
+=
(
tf
.
reduce_sum
(
model
.
losses
)
/
num_replicas
)
grads
=
tape
.
gradient
(
loss
,
trainable_variables
)
optimizer
.
apply_gradients
(
zip
(
grads
,
trainable_variables
))
training_accuracy
.
update_state
(
labels
,
logits
)
return
loss
...
...
official/vision/image_classification/resnet_model.py
View file @
d2a038fa
...
...
@@ -39,7 +39,16 @@ BATCH_NORM_DECAY = 0.9
BATCH_NORM_EPSILON
=
1e-5
def
identity_block
(
input_tensor
,
kernel_size
,
filters
,
stage
,
block
):
def
_gen_l2_regularizer
(
use_l2_regularizer
=
True
):
return
regularizers
.
l2
(
L2_WEIGHT_DECAY
)
if
use_l2_regularizer
else
None
def
identity_block
(
input_tensor
,
kernel_size
,
filters
,
stage
,
block
,
use_l2_regularizer
=
True
):
"""The identity block is the block that has no conv layer at shortcut.
Args:
...
...
@@ -48,6 +57,7 @@ def identity_block(input_tensor, kernel_size, filters, stage, block):
filters: list of integers, the filters of 3 conv layer at main path
stage: integer, current stage label, used for generating layer names
block: 'a','b'..., current block label, used for generating layer names
use_l2_regularizer: whether to use L2 regularizer on Conv layer.
Returns:
Output tensor for the block.
...
...
@@ -60,35 +70,51 @@ def identity_block(input_tensor, kernel_size, filters, stage, block):
conv_name_base
=
'res'
+
str
(
stage
)
+
block
+
'_branch'
bn_name_base
=
'bn'
+
str
(
stage
)
+
block
+
'_branch'
x
=
layers
.
Conv2D
(
filters1
,
(
1
,
1
),
use_bias
=
False
,
x
=
layers
.
Conv2D
(
filters1
,
(
1
,
1
),
use_bias
=
False
,
kernel_initializer
=
'he_normal'
,
kernel_regularizer
=
regularizers
.
l2
(
L2_WEIGHT_DECAY
),
name
=
conv_name_base
+
'2a'
)(
input_tensor
)
x
=
layers
.
BatchNormalization
(
axis
=
bn_axis
,
kernel_regularizer
=
_gen_l2_regularizer
(
use_l2_regularizer
),
name
=
conv_name_base
+
'2a'
)(
input_tensor
)
x
=
layers
.
BatchNormalization
(
axis
=
bn_axis
,
momentum
=
BATCH_NORM_DECAY
,
epsilon
=
BATCH_NORM_EPSILON
,
name
=
bn_name_base
+
'2a'
)(
x
)
name
=
bn_name_base
+
'2a'
)(
x
)
x
=
layers
.
Activation
(
'relu'
)(
x
)
x
=
layers
.
Conv2D
(
filters2
,
kernel_size
,
padding
=
'same'
,
use_bias
=
False
,
x
=
layers
.
Conv2D
(
filters2
,
kernel_size
,
padding
=
'same'
,
use_bias
=
False
,
kernel_initializer
=
'he_normal'
,
kernel_regularizer
=
regularizers
.
l2
(
L2_WEIGHT_DECAY
),
name
=
conv_name_base
+
'2b'
)(
x
)
x
=
layers
.
BatchNormalization
(
axis
=
bn_axis
,
kernel_regularizer
=
_gen_l2_regularizer
(
use_l2_regularizer
),
name
=
conv_name_base
+
'2b'
)(
x
)
x
=
layers
.
BatchNormalization
(
axis
=
bn_axis
,
momentum
=
BATCH_NORM_DECAY
,
epsilon
=
BATCH_NORM_EPSILON
,
name
=
bn_name_base
+
'2b'
)(
x
)
name
=
bn_name_base
+
'2b'
)(
x
)
x
=
layers
.
Activation
(
'relu'
)(
x
)
x
=
layers
.
Conv2D
(
filters3
,
(
1
,
1
),
use_bias
=
False
,
x
=
layers
.
Conv2D
(
filters3
,
(
1
,
1
),
use_bias
=
False
,
kernel_initializer
=
'he_normal'
,
kernel_regularizer
=
regularizers
.
l2
(
L2_WEIGHT_DECAY
),
name
=
conv_name_base
+
'2c'
)(
x
)
x
=
layers
.
BatchNormalization
(
axis
=
bn_axis
,
kernel_regularizer
=
_gen_l2_regularizer
(
use_l2_regularizer
),
name
=
conv_name_base
+
'2c'
)(
x
)
x
=
layers
.
BatchNormalization
(
axis
=
bn_axis
,
momentum
=
BATCH_NORM_DECAY
,
epsilon
=
BATCH_NORM_EPSILON
,
name
=
bn_name_base
+
'2c'
)(
x
)
name
=
bn_name_base
+
'2c'
)(
x
)
x
=
layers
.
add
([
x
,
input_tensor
])
x
=
layers
.
Activation
(
'relu'
)(
x
)
...
...
@@ -100,7 +126,8 @@ def conv_block(input_tensor,
filters
,
stage
,
block
,
strides
=
(
2
,
2
)):
strides
=
(
2
,
2
),
use_l2_regularizer
=
True
):
"""A block that has a conv layer at shortcut.
Note that from stage 3,
...
...
@@ -114,6 +141,7 @@ def conv_block(input_tensor,
stage: integer, current stage label, used for generating layer names
block: 'a','b'..., current block label, used for generating layer names
strides: Strides for the second conv layer in the block.
use_l2_regularizer: whether to use L2 regularizer on Conv layer.
Returns:
Output tensor for the block.
...
...
@@ -126,114 +154,231 @@ def conv_block(input_tensor,
conv_name_base
=
'res'
+
str
(
stage
)
+
block
+
'_branch'
bn_name_base
=
'bn'
+
str
(
stage
)
+
block
+
'_branch'
x
=
layers
.
Conv2D
(
filters1
,
(
1
,
1
),
use_bias
=
False
,
x
=
layers
.
Conv2D
(
filters1
,
(
1
,
1
),
use_bias
=
False
,
kernel_initializer
=
'he_normal'
,
kernel_regularizer
=
regularizers
.
l2
(
L2_WEIGHT_DECAY
),
name
=
conv_name_base
+
'2a'
)(
input_tensor
)
x
=
layers
.
BatchNormalization
(
axis
=
bn_axis
,
kernel_regularizer
=
_gen_l2_regularizer
(
use_l2_regularizer
),
name
=
conv_name_base
+
'2a'
)(
input_tensor
)
x
=
layers
.
BatchNormalization
(
axis
=
bn_axis
,
momentum
=
BATCH_NORM_DECAY
,
epsilon
=
BATCH_NORM_EPSILON
,
name
=
bn_name_base
+
'2a'
)(
x
)
name
=
bn_name_base
+
'2a'
)(
x
)
x
=
layers
.
Activation
(
'relu'
)(
x
)
x
=
layers
.
Conv2D
(
filters2
,
kernel_size
,
strides
=
strides
,
padding
=
'same'
,
use_bias
=
False
,
kernel_initializer
=
'he_normal'
,
kernel_regularizer
=
regularizers
.
l2
(
L2_WEIGHT_DECAY
),
name
=
conv_name_base
+
'2b'
)(
x
)
x
=
layers
.
BatchNormalization
(
axis
=
bn_axis
,
x
=
layers
.
Conv2D
(
filters2
,
kernel_size
,
strides
=
strides
,
padding
=
'same'
,
use_bias
=
False
,
kernel_initializer
=
'he_normal'
,
kernel_regularizer
=
_gen_l2_regularizer
(
use_l2_regularizer
),
name
=
conv_name_base
+
'2b'
)(
x
)
x
=
layers
.
BatchNormalization
(
axis
=
bn_axis
,
momentum
=
BATCH_NORM_DECAY
,
epsilon
=
BATCH_NORM_EPSILON
,
name
=
bn_name_base
+
'2b'
)(
x
)
name
=
bn_name_base
+
'2b'
)(
x
)
x
=
layers
.
Activation
(
'relu'
)(
x
)
x
=
layers
.
Conv2D
(
filters3
,
(
1
,
1
),
use_bias
=
False
,
x
=
layers
.
Conv2D
(
filters3
,
(
1
,
1
),
use_bias
=
False
,
kernel_initializer
=
'he_normal'
,
kernel_regularizer
=
regularizers
.
l2
(
L2_WEIGHT_DECAY
),
name
=
conv_name_base
+
'2c'
)(
x
)
x
=
layers
.
BatchNormalization
(
axis
=
bn_axis
,
kernel_regularizer
=
_gen_l2_regularizer
(
use_l2_regularizer
),
name
=
conv_name_base
+
'2c'
)(
x
)
x
=
layers
.
BatchNormalization
(
axis
=
bn_axis
,
momentum
=
BATCH_NORM_DECAY
,
epsilon
=
BATCH_NORM_EPSILON
,
name
=
bn_name_base
+
'2c'
)(
x
)
name
=
bn_name_base
+
'2c'
)(
x
)
shortcut
=
layers
.
Conv2D
(
filters3
,
(
1
,
1
),
strides
=
strides
,
use_bias
=
False
,
shortcut
=
layers
.
Conv2D
(
filters3
,
(
1
,
1
),
strides
=
strides
,
use_bias
=
False
,
kernel_initializer
=
'he_normal'
,
kernel_regularizer
=
regularizers
.
l2
(
L2_WEIGHT_DECAY
),
name
=
conv_name_base
+
'1'
)(
input_tensor
)
shortcut
=
layers
.
BatchNormalization
(
axis
=
bn_axis
,
kernel_regularizer
=
_gen_l2_regularizer
(
use_l2_regularizer
),
name
=
conv_name_base
+
'1'
)(
input_tensor
)
shortcut
=
layers
.
BatchNormalization
(
axis
=
bn_axis
,
momentum
=
BATCH_NORM_DECAY
,
epsilon
=
BATCH_NORM_EPSILON
,
name
=
bn_name_base
+
'1'
)(
shortcut
)
name
=
bn_name_base
+
'1'
)(
shortcut
)
x
=
layers
.
add
([
x
,
shortcut
])
x
=
layers
.
Activation
(
'relu'
)(
x
)
return
x
def
resnet50
(
num_classes
,
dtype
=
'float32'
,
batch_size
=
None
):
def
resnet50
(
num_classes
,
dtype
=
'float32'
,
batch_size
=
None
,
use_l2_regularizer
=
True
):
"""Instantiates the ResNet50 architecture.
Args:
num_classes: `int` number of classes for image classification.
dtype: dtype to use float32 or float16 are most common.
batch_size: Size of the batches for each step.
use_l2_regularizer: whether to use L2 regularizer on Conv/Dense layer.
Returns:
A Keras model instance.
"""
input_shape
=
(
224
,
224
,
3
)
img_input
=
layers
.
Input
(
shape
=
input_shape
,
dtype
=
dtype
,
batch_size
=
batch_size
)
img_input
=
layers
.
Input
(
shape
=
input_shape
,
dtype
=
dtype
,
batch_size
=
batch_size
)
if
backend
.
image_data_format
()
==
'channels_first'
:
x
=
layers
.
Lambda
(
lambda
x
:
backend
.
permute_dimensions
(
x
,
(
0
,
3
,
1
,
2
)),
name
=
'transpose'
)(
img_input
)
x
=
layers
.
Lambda
(
lambda
x
:
backend
.
permute_dimensions
(
x
,
(
0
,
3
,
1
,
2
)),
name
=
'transpose'
)(
img_input
)
bn_axis
=
1
else
:
# channels_last
x
=
img_input
bn_axis
=
3
x
=
layers
.
ZeroPadding2D
(
padding
=
(
3
,
3
),
name
=
'conv1_pad'
)(
x
)
x
=
layers
.
Conv2D
(
64
,
(
7
,
7
),
x
=
layers
.
Conv2D
(
64
,
(
7
,
7
),
strides
=
(
2
,
2
),
padding
=
'valid'
,
use_bias
=
False
,
padding
=
'valid'
,
use_bias
=
False
,
kernel_initializer
=
'he_normal'
,
kernel_regularizer
=
regularizers
.
l2
(
L2_WEIGHT_DECAY
),
name
=
'conv1'
)(
x
)
x
=
layers
.
BatchNormalization
(
axis
=
bn_axis
,
kernel_regularizer
=
_gen_l2_regularizer
(
use_l2_regularizer
),
name
=
'conv1'
)(
x
)
x
=
layers
.
BatchNormalization
(
axis
=
bn_axis
,
momentum
=
BATCH_NORM_DECAY
,
epsilon
=
BATCH_NORM_EPSILON
,
name
=
'bn_conv1'
)(
x
)
name
=
'bn_conv1'
)(
x
)
x
=
layers
.
Activation
(
'relu'
)(
x
)
x
=
layers
.
MaxPooling2D
((
3
,
3
),
strides
=
(
2
,
2
),
padding
=
'same'
)(
x
)
x
=
conv_block
(
x
,
3
,
[
64
,
64
,
256
],
stage
=
2
,
block
=
'a'
,
strides
=
(
1
,
1
))
x
=
identity_block
(
x
,
3
,
[
64
,
64
,
256
],
stage
=
2
,
block
=
'b'
)
x
=
identity_block
(
x
,
3
,
[
64
,
64
,
256
],
stage
=
2
,
block
=
'c'
)
x
=
conv_block
(
x
,
3
,
[
128
,
128
,
512
],
stage
=
3
,
block
=
'a'
)
x
=
identity_block
(
x
,
3
,
[
128
,
128
,
512
],
stage
=
3
,
block
=
'b'
)
x
=
identity_block
(
x
,
3
,
[
128
,
128
,
512
],
stage
=
3
,
block
=
'c'
)
x
=
identity_block
(
x
,
3
,
[
128
,
128
,
512
],
stage
=
3
,
block
=
'd'
)
x
=
conv_block
(
x
,
3
,
[
256
,
256
,
1024
],
stage
=
4
,
block
=
'a'
)
x
=
identity_block
(
x
,
3
,
[
256
,
256
,
1024
],
stage
=
4
,
block
=
'b'
)
x
=
identity_block
(
x
,
3
,
[
256
,
256
,
1024
],
stage
=
4
,
block
=
'c'
)
x
=
identity_block
(
x
,
3
,
[
256
,
256
,
1024
],
stage
=
4
,
block
=
'd'
)
x
=
identity_block
(
x
,
3
,
[
256
,
256
,
1024
],
stage
=
4
,
block
=
'e'
)
x
=
identity_block
(
x
,
3
,
[
256
,
256
,
1024
],
stage
=
4
,
block
=
'f'
)
x
=
conv_block
(
x
,
3
,
[
512
,
512
,
2048
],
stage
=
5
,
block
=
'a'
)
x
=
identity_block
(
x
,
3
,
[
512
,
512
,
2048
],
stage
=
5
,
block
=
'b'
)
x
=
identity_block
(
x
,
3
,
[
512
,
512
,
2048
],
stage
=
5
,
block
=
'c'
)
x
=
conv_block
(
x
,
3
,
[
64
,
64
,
256
],
stage
=
2
,
block
=
'a'
,
strides
=
(
1
,
1
),
use_l2_regularizer
=
use_l2_regularizer
)
x
=
identity_block
(
x
,
3
,
[
64
,
64
,
256
],
stage
=
2
,
block
=
'b'
,
use_l2_regularizer
=
use_l2_regularizer
)
x
=
identity_block
(
x
,
3
,
[
64
,
64
,
256
],
stage
=
2
,
block
=
'c'
,
use_l2_regularizer
=
use_l2_regularizer
)
x
=
conv_block
(
x
,
3
,
[
128
,
128
,
512
],
stage
=
3
,
block
=
'a'
,
use_l2_regularizer
=
use_l2_regularizer
)
x
=
identity_block
(
x
,
3
,
[
128
,
128
,
512
],
stage
=
3
,
block
=
'b'
,
use_l2_regularizer
=
use_l2_regularizer
)
x
=
identity_block
(
x
,
3
,
[
128
,
128
,
512
],
stage
=
3
,
block
=
'c'
,
use_l2_regularizer
=
use_l2_regularizer
)
x
=
identity_block
(
x
,
3
,
[
128
,
128
,
512
],
stage
=
3
,
block
=
'd'
,
use_l2_regularizer
=
use_l2_regularizer
)
x
=
conv_block
(
x
,
3
,
[
256
,
256
,
1024
],
stage
=
4
,
block
=
'a'
,
use_l2_regularizer
=
use_l2_regularizer
)
x
=
identity_block
(
x
,
3
,
[
256
,
256
,
1024
],
stage
=
4
,
block
=
'b'
,
use_l2_regularizer
=
use_l2_regularizer
)
x
=
identity_block
(
x
,
3
,
[
256
,
256
,
1024
],
stage
=
4
,
block
=
'c'
,
use_l2_regularizer
=
use_l2_regularizer
)
x
=
identity_block
(
x
,
3
,
[
256
,
256
,
1024
],
stage
=
4
,
block
=
'd'
,
use_l2_regularizer
=
use_l2_regularizer
)
x
=
identity_block
(
x
,
3
,
[
256
,
256
,
1024
],
stage
=
4
,
block
=
'e'
,
use_l2_regularizer
=
use_l2_regularizer
)
x
=
identity_block
(
x
,
3
,
[
256
,
256
,
1024
],
stage
=
4
,
block
=
'f'
,
use_l2_regularizer
=
use_l2_regularizer
)
x
=
conv_block
(
x
,
3
,
[
512
,
512
,
2048
],
stage
=
5
,
block
=
'a'
,
use_l2_regularizer
=
use_l2_regularizer
)
x
=
identity_block
(
x
,
3
,
[
512
,
512
,
2048
],
stage
=
5
,
block
=
'b'
,
use_l2_regularizer
=
use_l2_regularizer
)
x
=
identity_block
(
x
,
3
,
[
512
,
512
,
2048
],
stage
=
5
,
block
=
'c'
,
use_l2_regularizer
=
use_l2_regularizer
)
rm_axes
=
[
1
,
2
]
if
backend
.
image_data_format
()
==
'channels_last'
else
[
2
,
3
]
x
=
layers
.
Lambda
(
lambda
x
:
backend
.
mean
(
x
,
rm_axes
),
name
=
'reduce_mean'
)(
x
)
x
=
layers
.
Dense
(
num_classes
,
kernel_initializer
=
initializers
.
RandomNormal
(
stddev
=
0.01
),
kernel_regularizer
=
regularizers
.
l2
(
L2_WEIGHT_DECAY
),
bias_regularizer
=
regularizers
.
l2
(
L2_WEIGHT_DECAY
),
name
=
'fc1000'
)(
x
)
kernel_regularizer
=
_gen_l2_regularizer
(
use_l2_regularizer
),
bias_regularizer
=
_gen_l2_regularizer
(
use_l2_regularizer
),
name
=
'fc1000'
)(
x
)
# TODO(reedwm): Remove manual casts once mixed precision can be enabled with a
# single line of code.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment