Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
39774bc8
Commit
39774bc8
authored
Nov 11, 2020
by
A. Unique TensorFlower
Browse files
Internal change
PiperOrigin-RevId: 341852981
parent
c061dace
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
47 additions
and
21 deletions
+47
-21
official/vision/beta/configs/backbones.py
official/vision/beta/configs/backbones.py
+2
-1
official/vision/beta/modeling/backbones/resnet.py
official/vision/beta/modeling/backbones/resnet.py
+12
-1
official/vision/beta/modeling/backbones/resnet_test.py
official/vision/beta/modeling/backbones/resnet_test.py
+10
-6
official/vision/beta/modeling/backbones/spinenet.py
official/vision/beta/modeling/backbones/spinenet.py
+3
-13
official/vision/beta/modeling/layers/nn_layers.py
official/vision/beta/modeling/layers/nn_layers.py
+20
-0
No files found.
official/vision/beta/configs/backbones.py
View file @
39774bc8
...
...
@@ -28,6 +28,7 @@ class ResNet(hyperparams.Config):
model_id
:
int
=
50
stem_type
:
str
=
'v0'
se_ratio
:
float
=
0.0
stochastic_depth_drop_rate
:
float
=
0.0
@
dataclasses
.
dataclass
...
...
@@ -41,8 +42,8 @@ class DilatedResNet(hyperparams.Config):
class
EfficientNet
(
hyperparams
.
Config
):
"""EfficientNet config."""
model_id
:
str
=
'b0'
stochastic_depth_drop_rate
:
float
=
0.0
se_ratio
:
float
=
0.0
stochastic_depth_drop_rate
:
float
=
0.0
@
dataclasses
.
dataclass
...
...
official/vision/beta/modeling/backbones/resnet.py
View file @
39774bc8
...
...
@@ -24,6 +24,7 @@ import tensorflow as tf
from
official.modeling
import
tf_utils
from
official.vision.beta.modeling.backbones
import
factory
from
official.vision.beta.modeling.layers
import
nn_blocks
from
official.vision.beta.modeling.layers
import
nn_layers
layers
=
tf
.
keras
.
layers
...
...
@@ -80,6 +81,7 @@ class ResNet(tf.keras.Model):
input_specs
=
layers
.
InputSpec
(
shape
=
[
None
,
None
,
None
,
3
]),
stem_type
=
'v0'
,
se_ratio
=
None
,
init_stochastic_depth_rate
=
0.0
,
activation
=
'relu'
,
use_sync_bn
=
False
,
norm_momentum
=
0.99
,
...
...
@@ -96,6 +98,7 @@ class ResNet(tf.keras.Model):
stem_type: `str` stem type of ResNet. Default to `v0`. If set to `v1`,
use ResNet-C type stem (https://arxiv.org/abs/1812.01187).
se_ratio: `float` or None. Ratio of the Squeeze-and-Excitation layer.
init_stochastic_depth_rate: `float` initial stochastic depth rate.
activation: `str` name of the activation function.
use_sync_bn: if True, use synchronized batch normalization.
norm_momentum: `float` normalization omentum for the moving average.
...
...
@@ -112,6 +115,7 @@ class ResNet(tf.keras.Model):
self
.
_input_specs
=
input_specs
self
.
_stem_type
=
stem_type
self
.
_se_ratio
=
se_ratio
self
.
_init_stochastic_depth_rate
=
init_stochastic_depth_rate
self
.
_use_sync_bn
=
use_sync_bn
self
.
_activation
=
activation
self
.
_norm_momentum
=
norm_momentum
...
...
@@ -195,7 +199,6 @@ class ResNet(tf.keras.Model):
x
=
layers
.
MaxPool2D
(
pool_size
=
3
,
strides
=
2
,
padding
=
'same'
)(
x
)
# TODO(xianzhi): keep a list of blocks to make blocks accessible.
endpoints
=
{}
for
i
,
spec
in
enumerate
(
RESNET_SPECS
[
model_id
]):
if
spec
[
0
]
==
'residual'
:
...
...
@@ -210,6 +213,8 @@ class ResNet(tf.keras.Model):
strides
=
(
1
if
i
==
0
else
2
),
block_fn
=
block_fn
,
block_repeats
=
spec
[
2
],
stochastic_depth_drop_rate
=
nn_layers
.
get_stochastic_depth_rate
(
self
.
_init_stochastic_depth_rate
,
i
+
2
,
5
),
name
=
'block_group_l{}'
.
format
(
i
+
2
))
endpoints
[
str
(
i
+
2
)]
=
x
...
...
@@ -223,6 +228,7 @@ class ResNet(tf.keras.Model):
strides
,
block_fn
,
block_repeats
=
1
,
stochastic_depth_drop_rate
=
0.0
,
name
=
'block_group'
):
"""Creates one group of blocks for the ResNet model.
...
...
@@ -233,6 +239,7 @@ class ResNet(tf.keras.Model):
greater than 1, this layer will downsample the input.
block_fn: Either `nn_blocks.ResidualBlock` or `nn_blocks.BottleneckBlock`.
block_repeats: `int` number of blocks contained in the layer.
stochastic_depth_drop_rate: `float` drop rate of the current block group.
name: `str`name for the block.
Returns:
...
...
@@ -242,6 +249,7 @@ class ResNet(tf.keras.Model):
filters
=
filters
,
strides
=
strides
,
use_projection
=
True
,
stochastic_depth_drop_rate
=
stochastic_depth_drop_rate
,
se_ratio
=
self
.
_se_ratio
,
kernel_initializer
=
self
.
_kernel_initializer
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
...
...
@@ -257,6 +265,7 @@ class ResNet(tf.keras.Model):
filters
=
filters
,
strides
=
1
,
use_projection
=
False
,
stochastic_depth_drop_rate
=
stochastic_depth_drop_rate
,
se_ratio
=
self
.
_se_ratio
,
kernel_initializer
=
self
.
_kernel_initializer
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
...
...
@@ -275,6 +284,7 @@ class ResNet(tf.keras.Model):
'stem_type'
:
self
.
_stem_type
,
'activation'
:
self
.
_activation
,
'se_ratio'
:
self
.
_se_ratio
,
'init_stochastic_depth_rate'
:
self
.
_init_stochastic_depth_rate
,
'use_sync_bn'
:
self
.
_use_sync_bn
,
'norm_momentum'
:
self
.
_norm_momentum
,
'norm_epsilon'
:
self
.
_norm_epsilon
,
...
...
@@ -311,6 +321,7 @@ def build_resnet(
input_specs
=
input_specs
,
stem_type
=
backbone_cfg
.
stem_type
,
se_ratio
=
backbone_cfg
.
se_ratio
,
init_stochastic_depth_rate
=
backbone_cfg
.
stochastic_depth_drop_rate
,
activation
=
norm_activation_config
.
activation
,
use_sync_bn
=
norm_activation_config
.
use_sync_bn
,
norm_momentum
=
norm_activation_config
.
norm_momentum
,
...
...
official/vision/beta/modeling/backbones/resnet_test.py
View file @
39774bc8
...
...
@@ -84,17 +84,20 @@ class ResNetTest(parameterized.TestCase, tf.test.TestCase):
_
=
network
(
inputs
)
@
parameterized
.
parameters
(
(
128
,
34
,
1
,
'v0'
,
None
),
(
128
,
34
,
1
,
'v1'
,
0.25
),
(
128
,
50
,
4
,
'v0'
,
None
),
(
128
,
50
,
4
,
'v1'
,
0.25
),
(
128
,
34
,
1
,
'v0'
,
None
,
0.0
),
(
128
,
34
,
1
,
'v1'
,
0.25
,
0.2
),
(
128
,
50
,
4
,
'v0'
,
None
,
0.0
),
(
128
,
50
,
4
,
'v1'
,
0.25
,
0.2
),
)
def
test_resnet_addons
(
self
,
input_size
,
model_id
,
endpoint_filter_scale
,
stem_type
,
se_ratio
):
stem_type
,
se_ratio
,
init_stochastic_depth_rate
):
"""Test creation of ResNet family models."""
tf
.
keras
.
backend
.
set_image_data_format
(
'channels_last'
)
network
=
resnet
.
ResNet
(
model_id
=
model_id
,
stem_type
=
stem_type
,
se_ratio
=
se_ratio
)
model_id
=
model_id
,
stem_type
=
stem_type
,
se_ratio
=
se_ratio
,
init_stochastic_depth_rate
=
init_stochastic_depth_rate
)
inputs
=
tf
.
keras
.
Input
(
shape
=
(
input_size
,
input_size
,
3
),
batch_size
=
1
)
_
=
network
(
inputs
)
...
...
@@ -115,6 +118,7 @@ class ResNetTest(parameterized.TestCase, tf.test.TestCase):
model_id
=
50
,
stem_type
=
'v0'
,
se_ratio
=
None
,
init_stochastic_depth_rate
=
0.0
,
use_sync_bn
=
False
,
activation
=
'relu'
,
norm_momentum
=
0.99
,
...
...
official/vision/beta/modeling/backbones/spinenet.py
View file @
39774bc8
...
...
@@ -27,6 +27,7 @@ import tensorflow as tf
from
official.modeling
import
tf_utils
from
official.vision.beta.modeling.backbones
import
factory
from
official.vision.beta.modeling.layers
import
nn_blocks
from
official.vision.beta.modeling.layers
import
nn_layers
from
official.vision.beta.ops
import
spatial_transform_ops
layers
=
tf
.
keras
.
layers
...
...
@@ -114,17 +115,6 @@ def build_block_specs(block_specs=None):
return
[
BlockSpec
(
*
b
)
for
b
in
block_specs
]
def
get_stochastic_depth_rate
(
init_rate
,
i
,
n
):
"""Get drop connect rate for the ith block."""
if
init_rate
is
not
None
:
if
init_rate
<
0
or
init_rate
>
1
:
raise
ValueError
(
'Initial drop rate must be within 0 and 1.'
)
dc_rate
=
init_rate
*
float
(
i
+
1
)
/
n
else
:
dc_rate
=
None
return
dc_rate
@
tf
.
keras
.
utils
.
register_keras_serializable
(
package
=
'Vision'
)
class
SpineNet
(
tf
.
keras
.
Model
):
"""Class to build SpineNet models."""
...
...
@@ -350,8 +340,8 @@ class SpineNet(tf.keras.Model):
strides
=
1
,
block_fn_cand
=
target_block_fn
,
block_repeats
=
self
.
_block_repeats
,
stochastic_depth_drop_rate
=
get_stochastic_depth_rate
(
self
.
_init_stochastic_depth_rate
,
i
,
len
(
self
.
_block_specs
)),
stochastic_depth_drop_rate
=
nn_layers
.
get_stochastic_depth_rate
(
self
.
_init_stochastic_depth_rate
,
i
+
1
,
len
(
self
.
_block_specs
)),
name
=
'scale_permuted_block_{}'
.
format
(
i
+
1
))
net
.
append
(
x
)
...
...
official/vision/beta/modeling/layers/nn_layers.py
View file @
39774bc8
...
...
@@ -167,6 +167,26 @@ class SqueezeExcitation(tf.keras.layers.Layer):
return
x
*
inputs
def
get_stochastic_depth_rate
(
init_rate
,
i
,
n
):
"""Get drop connect rate for the ith block.
Args:
init_rate: `float` initial drop rate.
i: `int` order of the current block.
n: `int` total number of blocks.
Returns:
Drop rate of the ith block.
"""
if
init_rate
is
not
None
:
if
init_rate
<
0
or
init_rate
>
1
:
raise
ValueError
(
'Initial drop rate must be within 0 and 1.'
)
rate
=
init_rate
*
float
(
i
)
/
n
else
:
rate
=
None
return
rate
@
tf
.
keras
.
utils
.
register_keras_serializable
(
package
=
'Vision'
)
class
StochasticDepth
(
tf
.
keras
.
layers
.
Layer
):
"""Stochastic depth layer."""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment