Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
7a257585
Commit
7a257585
authored
Mar 25, 2020
by
Hongkun Yu
Committed by
A. Unique TensorFlower
Mar 25, 2020
Browse files
Readability: Avoid global variables in resnet construction.
PiperOrigin-RevId: 302932162
parent
ad09cf49
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
58 additions
and
126 deletions
+58
-126
official/vision/image_classification/resnet/resnet_model.py
official/vision/image_classification/resnet/resnet_model.py
+56
-124
official/vision/image_classification/resnet/resnet_runnable.py
...ial/vision/image_classification/resnet/resnet_runnable.py
+2
-2
No files found.
official/vision/image_classification/resnet/resnet_model.py
View file @
7a257585
...
@@ -35,15 +35,11 @@ from tensorflow.python.keras import models
...
@@ -35,15 +35,11 @@ from tensorflow.python.keras import models
from
tensorflow.python.keras
import
regularizers
from
tensorflow.python.keras
import
regularizers
from
official.vision.image_classification.resnet
import
imagenet_preprocessing
from
official.vision.image_classification.resnet
import
imagenet_preprocessing
L2_WEIGHT_DECAY
=
1e-4
BATCH_NORM_DECAY
=
0.9
BATCH_NORM_EPSILON
=
1e-5
layers
=
tf
.
keras
.
layers
layers
=
tf
.
keras
.
layers
def
_gen_l2_regularizer
(
use_l2_regularizer
=
True
):
def
_gen_l2_regularizer
(
use_l2_regularizer
=
True
,
l2_weight_decay
=
1e-4
):
return
regularizers
.
l2
(
L
2_
WEIGHT_DECAY
)
if
use_l2_regularizer
else
None
return
regularizers
.
l2
(
l
2_
weight_decay
)
if
use_l2_regularizer
else
None
def
identity_block
(
input_tensor
,
def
identity_block
(
input_tensor
,
...
@@ -51,7 +47,9 @@ def identity_block(input_tensor,
...
@@ -51,7 +47,9 @@ def identity_block(input_tensor,
filters
,
filters
,
stage
,
stage
,
block
,
block
,
use_l2_regularizer
=
True
):
use_l2_regularizer
=
True
,
batch_norm_decay
=
0.9
,
batch_norm_epsilon
=
1e-5
):
"""The identity block is the block that has no conv layer at shortcut.
"""The identity block is the block that has no conv layer at shortcut.
Args:
Args:
...
@@ -61,6 +59,8 @@ def identity_block(input_tensor,
...
@@ -61,6 +59,8 @@ def identity_block(input_tensor,
stage: integer, current stage label, used for generating layer names
stage: integer, current stage label, used for generating layer names
block: 'a','b'..., current block label, used for generating layer names
block: 'a','b'..., current block label, used for generating layer names
use_l2_regularizer: whether to use L2 regularizer on Conv layer.
use_l2_regularizer: whether to use L2 regularizer on Conv layer.
batch_norm_decay: Momentum of batch norm layers.
batch_norm_epsilon: Epsilon of batch norm layers.
Returns:
Returns:
Output tensor for the block.
Output tensor for the block.
...
@@ -82,8 +82,8 @@ def identity_block(input_tensor,
...
@@ -82,8 +82,8 @@ def identity_block(input_tensor,
input_tensor
)
input_tensor
)
x
=
layers
.
BatchNormalization
(
x
=
layers
.
BatchNormalization
(
axis
=
bn_axis
,
axis
=
bn_axis
,
momentum
=
BATCH_NORM_DECAY
,
momentum
=
batch_norm_decay
,
epsilon
=
BATCH_NORM_EPSILON
,
epsilon
=
batch_norm_epsilon
,
name
=
bn_name_base
+
'2a'
)(
name
=
bn_name_base
+
'2a'
)(
x
)
x
)
x
=
layers
.
Activation
(
'relu'
)(
x
)
x
=
layers
.
Activation
(
'relu'
)(
x
)
...
@@ -99,8 +99,8 @@ def identity_block(input_tensor,
...
@@ -99,8 +99,8 @@ def identity_block(input_tensor,
x
)
x
)
x
=
layers
.
BatchNormalization
(
x
=
layers
.
BatchNormalization
(
axis
=
bn_axis
,
axis
=
bn_axis
,
momentum
=
BATCH_NORM_DECAY
,
momentum
=
batch_norm_decay
,
epsilon
=
BATCH_NORM_EPSILON
,
epsilon
=
batch_norm_epsilon
,
name
=
bn_name_base
+
'2b'
)(
name
=
bn_name_base
+
'2b'
)(
x
)
x
)
x
=
layers
.
Activation
(
'relu'
)(
x
)
x
=
layers
.
Activation
(
'relu'
)(
x
)
...
@@ -114,8 +114,8 @@ def identity_block(input_tensor,
...
@@ -114,8 +114,8 @@ def identity_block(input_tensor,
x
)
x
)
x
=
layers
.
BatchNormalization
(
x
=
layers
.
BatchNormalization
(
axis
=
bn_axis
,
axis
=
bn_axis
,
momentum
=
BATCH_NORM_DECAY
,
momentum
=
batch_norm_decay
,
epsilon
=
BATCH_NORM_EPSILON
,
epsilon
=
batch_norm_epsilon
,
name
=
bn_name_base
+
'2c'
)(
name
=
bn_name_base
+
'2c'
)(
x
)
x
)
...
@@ -130,7 +130,9 @@ def conv_block(input_tensor,
...
@@ -130,7 +130,9 @@ def conv_block(input_tensor,
stage
,
stage
,
block
,
block
,
strides
=
(
2
,
2
),
strides
=
(
2
,
2
),
use_l2_regularizer
=
True
):
use_l2_regularizer
=
True
,
batch_norm_decay
=
0.9
,
batch_norm_epsilon
=
1e-5
):
"""A block that has a conv layer at shortcut.
"""A block that has a conv layer at shortcut.
Note that from stage 3,
Note that from stage 3,
...
@@ -145,6 +147,8 @@ def conv_block(input_tensor,
...
@@ -145,6 +147,8 @@ def conv_block(input_tensor,
block: 'a','b'..., current block label, used for generating layer names
block: 'a','b'..., current block label, used for generating layer names
strides: Strides for the second conv layer in the block.
strides: Strides for the second conv layer in the block.
use_l2_regularizer: whether to use L2 regularizer on Conv layer.
use_l2_regularizer: whether to use L2 regularizer on Conv layer.
batch_norm_decay: Momentum of batch norm layers.
batch_norm_epsilon: Epsilon of batch norm layers.
Returns:
Returns:
Output tensor for the block.
Output tensor for the block.
...
@@ -166,8 +170,8 @@ def conv_block(input_tensor,
...
@@ -166,8 +170,8 @@ def conv_block(input_tensor,
input_tensor
)
input_tensor
)
x
=
layers
.
BatchNormalization
(
x
=
layers
.
BatchNormalization
(
axis
=
bn_axis
,
axis
=
bn_axis
,
momentum
=
BATCH_NORM_DECAY
,
momentum
=
batch_norm_decay
,
epsilon
=
BATCH_NORM_EPSILON
,
epsilon
=
batch_norm_epsilon
,
name
=
bn_name_base
+
'2a'
)(
name
=
bn_name_base
+
'2a'
)(
x
)
x
)
x
=
layers
.
Activation
(
'relu'
)(
x
)
x
=
layers
.
Activation
(
'relu'
)(
x
)
...
@@ -184,8 +188,8 @@ def conv_block(input_tensor,
...
@@ -184,8 +188,8 @@ def conv_block(input_tensor,
x
)
x
)
x
=
layers
.
BatchNormalization
(
x
=
layers
.
BatchNormalization
(
axis
=
bn_axis
,
axis
=
bn_axis
,
momentum
=
BATCH_NORM_DECAY
,
momentum
=
batch_norm_decay
,
epsilon
=
BATCH_NORM_EPSILON
,
epsilon
=
batch_norm_epsilon
,
name
=
bn_name_base
+
'2b'
)(
name
=
bn_name_base
+
'2b'
)(
x
)
x
)
x
=
layers
.
Activation
(
'relu'
)(
x
)
x
=
layers
.
Activation
(
'relu'
)(
x
)
...
@@ -199,8 +203,8 @@ def conv_block(input_tensor,
...
@@ -199,8 +203,8 @@ def conv_block(input_tensor,
x
)
x
)
x
=
layers
.
BatchNormalization
(
x
=
layers
.
BatchNormalization
(
axis
=
bn_axis
,
axis
=
bn_axis
,
momentum
=
BATCH_NORM_DECAY
,
momentum
=
batch_norm_decay
,
epsilon
=
BATCH_NORM_EPSILON
,
epsilon
=
batch_norm_epsilon
,
name
=
bn_name_base
+
'2c'
)(
name
=
bn_name_base
+
'2c'
)(
x
)
x
)
...
@@ -214,8 +218,8 @@ def conv_block(input_tensor,
...
@@ -214,8 +218,8 @@ def conv_block(input_tensor,
input_tensor
)
input_tensor
)
shortcut
=
layers
.
BatchNormalization
(
shortcut
=
layers
.
BatchNormalization
(
axis
=
bn_axis
,
axis
=
bn_axis
,
momentum
=
BATCH_NORM_DECAY
,
momentum
=
batch_norm_decay
,
epsilon
=
BATCH_NORM_EPSILON
,
epsilon
=
batch_norm_epsilon
,
name
=
bn_name_base
+
'1'
)(
name
=
bn_name_base
+
'1'
)(
shortcut
)
shortcut
)
...
@@ -227,7 +231,9 @@ def conv_block(input_tensor,
...
@@ -227,7 +231,9 @@ def conv_block(input_tensor,
def
resnet50
(
num_classes
,
def
resnet50
(
num_classes
,
batch_size
=
None
,
batch_size
=
None
,
use_l2_regularizer
=
True
,
use_l2_regularizer
=
True
,
rescale_inputs
=
False
):
rescale_inputs
=
False
,
batch_norm_decay
=
0.9
,
batch_norm_epsilon
=
1e-5
):
"""Instantiates the ResNet50 architecture.
"""Instantiates the ResNet50 architecture.
Args:
Args:
...
@@ -235,6 +241,8 @@ def resnet50(num_classes,
...
@@ -235,6 +241,8 @@ def resnet50(num_classes,
batch_size: Size of the batches for each step.
batch_size: Size of the batches for each step.
use_l2_regularizer: whether to use L2 regularizer on Conv/Dense layer.
use_l2_regularizer: whether to use L2 regularizer on Conv/Dense layer.
rescale_inputs: whether to rescale inputs from 0 to 1.
rescale_inputs: whether to rescale inputs from 0 to 1.
batch_norm_decay: Momentum of batch norm layers.
batch_norm_epsilon: Epsilon of batch norm layers.
Returns:
Returns:
A Keras model instance.
A Keras model instance.
...
@@ -260,6 +268,10 @@ def resnet50(num_classes,
...
@@ -260,6 +268,10 @@ def resnet50(num_classes,
else
:
# channels_last
else
:
# channels_last
bn_axis
=
3
bn_axis
=
3
block_config
=
dict
(
use_l2_regularizer
=
use_l2_regularizer
,
batch_norm_decay
=
batch_norm_decay
,
batch_norm_epsilon
=
batch_norm_epsilon
)
x
=
layers
.
ZeroPadding2D
(
padding
=
(
3
,
3
),
name
=
'conv1_pad'
)(
x
)
x
=
layers
.
ZeroPadding2D
(
padding
=
(
3
,
3
),
name
=
'conv1_pad'
)(
x
)
x
=
layers
.
Conv2D
(
x
=
layers
.
Conv2D
(
64
,
(
7
,
7
),
64
,
(
7
,
7
),
...
@@ -272,113 +284,33 @@ def resnet50(num_classes,
...
@@ -272,113 +284,33 @@ def resnet50(num_classes,
x
)
x
)
x
=
layers
.
BatchNormalization
(
x
=
layers
.
BatchNormalization
(
axis
=
bn_axis
,
axis
=
bn_axis
,
momentum
=
BATCH_NORM_DECAY
,
momentum
=
batch_norm_decay
,
epsilon
=
BATCH_NORM_EPSILON
,
epsilon
=
batch_norm_epsilon
,
name
=
'bn_conv1'
)(
name
=
'bn_conv1'
)(
x
)
x
)
x
=
layers
.
Activation
(
'relu'
)(
x
)
x
=
layers
.
Activation
(
'relu'
)(
x
)
x
=
layers
.
MaxPooling2D
((
3
,
3
),
strides
=
(
2
,
2
),
padding
=
'same'
)(
x
)
x
=
layers
.
MaxPooling2D
((
3
,
3
),
strides
=
(
2
,
2
),
padding
=
'same'
)(
x
)
x
=
conv_block
(
x
=
conv_block
(
x
,
x
,
3
,
[
64
,
64
,
256
],
stage
=
2
,
block
=
'a'
,
strides
=
(
1
,
1
),
**
block_config
)
3
,
[
64
,
64
,
256
],
x
=
identity_block
(
x
,
3
,
[
64
,
64
,
256
],
stage
=
2
,
block
=
'b'
,
**
block_config
)
stage
=
2
,
x
=
identity_block
(
x
,
3
,
[
64
,
64
,
256
],
stage
=
2
,
block
=
'c'
,
**
block_config
)
block
=
'a'
,
strides
=
(
1
,
1
),
x
=
conv_block
(
x
,
3
,
[
128
,
128
,
512
],
stage
=
3
,
block
=
'a'
,
**
block_config
)
use_l2_regularizer
=
use_l2_regularizer
)
x
=
identity_block
(
x
,
3
,
[
128
,
128
,
512
],
stage
=
3
,
block
=
'b'
,
**
block_config
)
x
=
identity_block
(
x
=
identity_block
(
x
,
3
,
[
128
,
128
,
512
],
stage
=
3
,
block
=
'c'
,
**
block_config
)
x
,
x
=
identity_block
(
x
,
3
,
[
128
,
128
,
512
],
stage
=
3
,
block
=
'd'
,
**
block_config
)
3
,
[
64
,
64
,
256
],
stage
=
2
,
x
=
conv_block
(
x
,
3
,
[
256
,
256
,
1024
],
stage
=
4
,
block
=
'a'
,
**
block_config
)
block
=
'b'
,
x
=
identity_block
(
x
,
3
,
[
256
,
256
,
1024
],
stage
=
4
,
block
=
'b'
,
**
block_config
)
use_l2_regularizer
=
use_l2_regularizer
)
x
=
identity_block
(
x
,
3
,
[
256
,
256
,
1024
],
stage
=
4
,
block
=
'c'
,
**
block_config
)
x
=
identity_block
(
x
=
identity_block
(
x
,
3
,
[
256
,
256
,
1024
],
stage
=
4
,
block
=
'd'
,
**
block_config
)
x
,
x
=
identity_block
(
x
,
3
,
[
256
,
256
,
1024
],
stage
=
4
,
block
=
'e'
,
**
block_config
)
3
,
[
64
,
64
,
256
],
x
=
identity_block
(
x
,
3
,
[
256
,
256
,
1024
],
stage
=
4
,
block
=
'f'
,
**
block_config
)
stage
=
2
,
block
=
'c'
,
x
=
conv_block
(
x
,
3
,
[
512
,
512
,
2048
],
stage
=
5
,
block
=
'a'
,
**
block_config
)
use_l2_regularizer
=
use_l2_regularizer
)
x
=
identity_block
(
x
,
3
,
[
512
,
512
,
2048
],
stage
=
5
,
block
=
'b'
,
**
block_config
)
x
=
identity_block
(
x
,
3
,
[
512
,
512
,
2048
],
stage
=
5
,
block
=
'c'
,
**
block_config
)
x
=
conv_block
(
x
,
3
,
[
128
,
128
,
512
],
stage
=
3
,
block
=
'a'
,
use_l2_regularizer
=
use_l2_regularizer
)
x
=
identity_block
(
x
,
3
,
[
128
,
128
,
512
],
stage
=
3
,
block
=
'b'
,
use_l2_regularizer
=
use_l2_regularizer
)
x
=
identity_block
(
x
,
3
,
[
128
,
128
,
512
],
stage
=
3
,
block
=
'c'
,
use_l2_regularizer
=
use_l2_regularizer
)
x
=
identity_block
(
x
,
3
,
[
128
,
128
,
512
],
stage
=
3
,
block
=
'd'
,
use_l2_regularizer
=
use_l2_regularizer
)
x
=
conv_block
(
x
,
3
,
[
256
,
256
,
1024
],
stage
=
4
,
block
=
'a'
,
use_l2_regularizer
=
use_l2_regularizer
)
x
=
identity_block
(
x
,
3
,
[
256
,
256
,
1024
],
stage
=
4
,
block
=
'b'
,
use_l2_regularizer
=
use_l2_regularizer
)
x
=
identity_block
(
x
,
3
,
[
256
,
256
,
1024
],
stage
=
4
,
block
=
'c'
,
use_l2_regularizer
=
use_l2_regularizer
)
x
=
identity_block
(
x
,
3
,
[
256
,
256
,
1024
],
stage
=
4
,
block
=
'd'
,
use_l2_regularizer
=
use_l2_regularizer
)
x
=
identity_block
(
x
,
3
,
[
256
,
256
,
1024
],
stage
=
4
,
block
=
'e'
,
use_l2_regularizer
=
use_l2_regularizer
)
x
=
identity_block
(
x
,
3
,
[
256
,
256
,
1024
],
stage
=
4
,
block
=
'f'
,
use_l2_regularizer
=
use_l2_regularizer
)
x
=
conv_block
(
x
,
3
,
[
512
,
512
,
2048
],
stage
=
5
,
block
=
'a'
,
use_l2_regularizer
=
use_l2_regularizer
)
x
=
identity_block
(
x
,
3
,
[
512
,
512
,
2048
],
stage
=
5
,
block
=
'b'
,
use_l2_regularizer
=
use_l2_regularizer
)
x
=
identity_block
(
x
,
3
,
[
512
,
512
,
2048
],
stage
=
5
,
block
=
'c'
,
use_l2_regularizer
=
use_l2_regularizer
)
x
=
layers
.
GlobalAveragePooling2D
()(
x
)
x
=
layers
.
GlobalAveragePooling2D
()(
x
)
x
=
layers
.
Dense
(
x
=
layers
.
Dense
(
...
...
official/vision/image_classification/resnet/resnet_runnable.py
View file @
7a257585
...
@@ -158,9 +158,9 @@ class ResnetRunnable(standard_runnable.StandardTrainable,
...
@@ -158,9 +158,9 @@ class ResnetRunnable(standard_runnable.StandardTrainable,
loss
=
tf
.
reduce_sum
(
prediction_loss
)
*
(
1.0
/
loss
=
tf
.
reduce_sum
(
prediction_loss
)
*
(
1.0
/
self
.
flags_obj
.
batch_size
)
self
.
flags_obj
.
batch_size
)
num_replicas
=
self
.
strategy
.
num_replicas_in_sync
num_replicas
=
self
.
strategy
.
num_replicas_in_sync
l2_weight_decay
=
1e-4
if
self
.
flags_obj
.
single_l2_loss_op
:
if
self
.
flags_obj
.
single_l2_loss_op
:
l2_loss
=
resnet_model
.
L2_WEIGHT_DECAY
*
2
*
tf
.
add_n
([
l2_loss
=
l2_weight_decay
*
2
*
tf
.
add_n
([
tf
.
nn
.
l2_loss
(
v
)
tf
.
nn
.
l2_loss
(
v
)
for
v
in
self
.
model
.
trainable_variables
for
v
in
self
.
model
.
trainable_variables
if
'bn'
not
in
v
.
name
if
'bn'
not
in
v
.
name
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment