ModelZoo / ResNet50_tensorflow · Commit 0168e493

Authored Oct 16, 2020 by Abdullah Rashwan; committed by A. Unique TensorFlower, Oct 16, 2020.

Internal change

PiperOrigin-RevId: 337456452
Parent: 4d85af94
Showing 5 changed files with 387 additions and 0 deletions (+387, −0):

  official/vision/beta/configs/backbones.py                       (+9, −0)
  official/vision/beta/modeling/backbones/__init__.py             (+1, −0)
  official/vision/beta/modeling/backbones/resnet_deeplab.py       (+261, −0)
  official/vision/beta/modeling/backbones/resnet_deeplab_test.py  (+111, −0)
  official/vision/beta/modeling/layers/nn_blocks.py               (+5, −0)
official/vision/beta/configs/backbones.py  (view file @ 0168e493)

@@ -28,6 +28,13 @@ class ResNet(hyperparams.Config):
   model_id: int = 50
 
 
+@dataclasses.dataclass
+class DilatedResNet(hyperparams.Config):
+  """DilatedResNet config."""
+  model_id: int = 50
+  output_stride: int = 16
+
+
 @dataclasses.dataclass
 class EfficientNet(hyperparams.Config):
   """EfficientNet config."""
@@ -65,6 +72,7 @@ class Backbone(hyperparams.OneOfConfig):
   Attributes:
     type: 'str', type of backbone be used, one the of fields below.
     resnet: resnet backbone config.
+    dilated_resnet: dilated resnet backbone for semantic segmentation config.
     revnet: revnet backbone config.
     efficientnet: efficientnet backbone config.
     spinenet: spinenet backbone config.
@@ -72,6 +80,7 @@ class Backbone(hyperparams.OneOfConfig):
   """
   type: Optional[str] = None
   resnet: ResNet = ResNet()
+  dilated_resnet: DilatedResNet = DilatedResNet()
   revnet: RevNet = RevNet()
   efficientnet: EfficientNet = EfficientNet()
   spinenet: SpineNet = SpineNet()
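Not part of the diff: a minimal sketch of how the new `dilated_resnet` option could be selected through the `Backbone` one-of config added above (field values are illustrative).

# Illustrative sketch, not part of this commit.
from official.vision.beta.configs import backbones

backbone_config = backbones.Backbone(
    type='dilated_resnet',
    dilated_resnet=backbones.DilatedResNet(model_id=101, output_stride=16))
print(backbone_config.get().output_stride)  # 16, from the selected sub-config.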
official/vision/beta/modeling/backbones/__init__.py  (view file @ 0168e493)

@@ -19,5 +19,6 @@ from official.vision.beta.modeling.backbones.efficientnet import EfficientNet
 from official.vision.beta.modeling.backbones.mobilenet import MobileNet
 from official.vision.beta.modeling.backbones.resnet import ResNet
 from official.vision.beta.modeling.backbones.resnet_3d import ResNet3D
+from official.vision.beta.modeling.backbones.resnet_deeplab import DilatedResNet
 from official.vision.beta.modeling.backbones.revnet import RevNet
 from official.vision.beta.modeling.backbones.spinenet import SpineNet
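Not part of the diff: with this import in place, the backbone is reachable from the package namespace, roughly as follows (assumes the TF Model Garden `official` package is installed).

# Illustrative sketch, not part of this commit.
from official.vision.beta.modeling import backbones

network = backbones.DilatedResNet(model_id=50, output_stride=16)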
official/vision/beta/modeling/backbones/resnet_deeplab.py  (new file, 0 → 100644, view file @ 0168e493)

# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains definitions of Residual Networks with Deeplab modifications."""

import numpy as np
import tensorflow as tf

from official.modeling import tf_utils
from official.vision.beta.modeling.backbones import factory
from official.vision.beta.modeling.layers import nn_blocks

layers = tf.keras.layers

# Specifications for different ResNet variants.
# Each entry specifies block configurations of the particular ResNet variant.
# Each element in the block configuration is in the following format:
# (block_fn, num_filters, block_repeats)
RESNET_SPECS = {
    50: [
        ('bottleneck', 64, 3),
        ('bottleneck', 128, 4),
        ('bottleneck', 256, 6),
        ('bottleneck', 512, 3),
    ],
    101: [
        ('bottleneck', 64, 3),
        ('bottleneck', 128, 4),
        ('bottleneck', 256, 23),
        ('bottleneck', 512, 3),
    ],
}


@tf.keras.utils.register_keras_serializable(package='Vision')
class DilatedResNet(tf.keras.Model):
  """Class to build ResNet model with Deeplabv3 modifications.

  This backbone is suitable for semantic segmentation. It was proposed in:
  [1] Liang-Chieh Chen, George Papandreou, Florian Schroff, Hartwig Adam
      Rethinking Atrous Convolution for Semantic Image Segmentation.
      arXiv:1706.05587
  """

  def __init__(self,
               model_id,
               output_stride,
               input_specs=layers.InputSpec(shape=[None, None, None, 3]),
               activation='relu',
               use_sync_bn=False,
               norm_momentum=0.99,
               norm_epsilon=0.001,
               kernel_initializer='VarianceScaling',
               kernel_regularizer=None,
               bias_regularizer=None,
               **kwargs):
    """ResNet with DeepLab modification initialization function.

    Args:
      model_id: `int` depth of ResNet backbone model.
      output_stride: `int` output stride, ratio of input to output resolution.
      input_specs: `tf.keras.layers.InputSpec` specs of the input tensor.
      activation: `str` name of the activation function.
      use_sync_bn: if True, use synchronized batch normalization.
      norm_momentum: `float` normalization momentum for the moving average.
      norm_epsilon: `float` small float added to variance to avoid dividing by
        zero.
      kernel_initializer: kernel_initializer for convolutional layers.
      kernel_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
        Default to None.
      bias_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
        Default to None.
      **kwargs: keyword arguments to be passed.
    """
    self._model_id = model_id
    self._output_stride = output_stride
    self._input_specs = input_specs
    self._use_sync_bn = use_sync_bn
    self._activation = activation
    self._norm_momentum = norm_momentum
    self._norm_epsilon = norm_epsilon
    if use_sync_bn:
      self._norm = layers.experimental.SyncBatchNormalization
    else:
      self._norm = layers.BatchNormalization
    self._kernel_initializer = kernel_initializer
    self._kernel_regularizer = kernel_regularizer
    self._bias_regularizer = bias_regularizer
    if tf.keras.backend.image_data_format() == 'channels_last':
      bn_axis = -1
    else:
      bn_axis = 1

    # Build ResNet.
    inputs = tf.keras.Input(shape=input_specs.shape[1:])

    x = layers.Conv2D(
        filters=64, kernel_size=7, strides=2, use_bias=False, padding='same',
        kernel_initializer=self._kernel_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer)(inputs)
    x = self._norm(
        axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)(x)
    x = tf_utils.get_activation(activation)(x)
    x = layers.MaxPool2D(pool_size=3, strides=2, padding='same')(x)

    normal_resnet_stage = int(np.math.log2(self._output_stride)) - 2

    endpoints = {}
    for i in range(normal_resnet_stage + 1):
      spec = RESNET_SPECS[model_id][i]
      if spec[0] == 'bottleneck':
        block_fn = nn_blocks.BottleneckBlock
      else:
        raise ValueError('Block fn `{}` is not supported.'.format(spec[0]))
      x = self._block_group(
          inputs=x,
          filters=spec[1],
          strides=(1 if i == 0 else 2),
          dilation_rate=1,
          block_fn=block_fn,
          block_repeats=spec[2],
          name='block_group_l{}'.format(i + 2))
      endpoints[str(i + 2)] = x

    dilation_rate = 2
    for i in range(normal_resnet_stage + 1, 7):
      spec = RESNET_SPECS[model_id][i] if i < 3 else RESNET_SPECS[model_id][-1]
      if spec[0] == 'bottleneck':
        block_fn = nn_blocks.BottleneckBlock
      else:
        raise ValueError('Block fn `{}` is not supported.'.format(spec[0]))
      x = self._block_group(
          inputs=x,
          filters=spec[1],
          strides=1,
          dilation_rate=dilation_rate,
          block_fn=block_fn,
          block_repeats=spec[2],
          name='block_group_l{}'.format(i + 2))
      dilation_rate *= 2
    endpoints[str(normal_resnet_stage + 2)] = x

    self._output_specs = {l: endpoints[l].get_shape() for l in endpoints}

    super(DilatedResNet, self).__init__(
        inputs=inputs, outputs=endpoints, **kwargs)

  def _block_group(self,
                   inputs,
                   filters,
                   strides,
                   dilation_rate,
                   block_fn,
                   block_repeats=1,
                   name='block_group'):
    """Creates one group of blocks for the ResNet model.

    Args:
      inputs: `Tensor` of size `[batch, channels, height, width]`.
      filters: `int` number of filters for the first convolution of the layer.
      strides: `int` stride to use for the first convolution of the layer. If
        greater than 1, this layer will downsample the input.
      dilation_rate: `int`, dilated convolution rate.
      block_fn: Either `nn_blocks.ResidualBlock` or `nn_blocks.BottleneckBlock`.
      block_repeats: `int` number of blocks contained in the layer.
      name: `str` name for the block.

    Returns:
      The output `Tensor` of the block layer.
    """
    x = block_fn(
        filters=filters,
        strides=strides,
        dilation_rate=dilation_rate,
        use_projection=True,
        kernel_initializer=self._kernel_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer,
        activation=self._activation,
        use_sync_bn=self._use_sync_bn,
        norm_momentum=self._norm_momentum,
        norm_epsilon=self._norm_epsilon)(inputs)

    for _ in range(1, block_repeats):
      x = block_fn(
          filters=filters,
          strides=1,
          dilation_rate=dilation_rate,
          use_projection=False,
          kernel_initializer=self._kernel_initializer,
          kernel_regularizer=self._kernel_regularizer,
          bias_regularizer=self._bias_regularizer,
          activation=self._activation,
          use_sync_bn=self._use_sync_bn,
          norm_momentum=self._norm_momentum,
          norm_epsilon=self._norm_epsilon)(x)

    return tf.identity(x, name=name)

  def get_config(self):
    config_dict = {
        'model_id': self._model_id,
        'output_stride': self._output_stride,
        'activation': self._activation,
        'use_sync_bn': self._use_sync_bn,
        'norm_momentum': self._norm_momentum,
        'norm_epsilon': self._norm_epsilon,
        'kernel_initializer': self._kernel_initializer,
        'kernel_regularizer': self._kernel_regularizer,
        'bias_regularizer': self._bias_regularizer,
    }
    return config_dict

  @classmethod
  def from_config(cls, config, custom_objects=None):
    return cls(**config)

  @property
  def output_specs(self):
    """A dict of {level: TensorShape} pairs for the model output."""
    return self._output_specs


@factory.register_backbone_builder('dilated_resnet')
def build_dilated_resnet(
    input_specs: tf.keras.layers.InputSpec,
    model_config,
    l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
  """Builds ResNet backbone from a config."""
  backbone_type = model_config.backbone.type
  backbone_cfg = model_config.backbone.get()
  norm_activation_config = model_config.norm_activation
  assert backbone_type == 'dilated_resnet', (f'Inconsistent backbone type '
                                             f'{backbone_type}')

  return DilatedResNet(
      model_id=backbone_cfg.model_id,
      output_stride=backbone_cfg.output_stride,
      input_specs=input_specs,
      activation=norm_activation_config.activation,
      use_sync_bn=norm_activation_config.use_sync_bn,
      norm_momentum=norm_activation_config.norm_momentum,
      norm_epsilon=norm_activation_config.norm_epsilon,
      kernel_regularizer=l2_regularizer)
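Not part of the diff: a minimal usage sketch of the new backbone. It builds the usual ResNet stem and early stages, then keeps stride 1 and grows the dilation rate once the requested `output_stride` is reached; the deepest endpoint is keyed by `log2(output_stride)`, matching the test below.

# Illustrative sketch, not part of this commit (input size is arbitrary).
import numpy as np
import tensorflow as tf

from official.vision.beta.modeling.backbones import resnet_deeplab

backbone = resnet_deeplab.DilatedResNet(model_id=50, output_stride=16)
images = tf.keras.Input(shape=(512, 512, 3), batch_size=1)
endpoints = backbone(images)

level = str(int(np.math.log2(16)))  # deepest endpoint key: '4'
print(endpoints[level].shape)       # (1, 32, 32, 2048)
print(backbone.output_specs)        # {level: TensorShape} for each endpoint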
official/vision/beta/modeling/backbones/resnet_deeplab_test.py  (new file, 0 → 100644, view file @ 0168e493)

# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for resnet_deeplab models."""

# Import libraries
from absl.testing import parameterized
import numpy as np
import tensorflow as tf

from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.vision.beta.modeling.backbones import resnet_deeplab


class ResNetTest(parameterized.TestCase, tf.test.TestCase):

  @parameterized.parameters(
      (128, 50, 4, 8),
      (128, 101, 4, 8),
      (128, 50, 4, 16),
      (128, 101, 4, 16),
  )
  def test_network_creation(self, input_size, model_id,
                            endpoint_filter_scale, output_stride):
    """Test creation of ResNet models."""
    tf.keras.backend.set_image_data_format('channels_last')

    network = resnet_deeplab.DilatedResNet(model_id=model_id,
                                           output_stride=output_stride)
    inputs = tf.keras.Input(shape=(input_size, input_size, 3), batch_size=1)
    endpoints = network(inputs)
    print(endpoints)

    self.assertAllEqual([
        1, input_size / output_stride, input_size / output_stride,
        512 * endpoint_filter_scale
    ], endpoints[str(int(np.math.log2(output_stride)))].shape.as_list())

  @combinations.generate(
      combinations.combine(
          strategy=[
              strategy_combinations.tpu_strategy,
              strategy_combinations.one_device_strategy_gpu,
          ],
          use_sync_bn=[False, True],
      ))
  def test_sync_bn_multiple_devices(self, strategy, use_sync_bn):
    """Test for sync bn on TPU and GPU devices."""
    inputs = np.random.rand(64, 128, 128, 3)

    tf.keras.backend.set_image_data_format('channels_last')

    with strategy.scope():
      network = resnet_deeplab.DilatedResNet(model_id=50, output_stride=8,
                                             use_sync_bn=use_sync_bn)
      _ = network(inputs)

  @parameterized.parameters(1, 3, 4)
  def test_input_specs(self, input_dim):
    """Test different input feature dimensions."""
    tf.keras.backend.set_image_data_format('channels_last')

    input_specs = tf.keras.layers.InputSpec(
        shape=[None, None, None, input_dim])
    network = resnet_deeplab.DilatedResNet(model_id=50, output_stride=8,
                                           input_specs=input_specs)

    inputs = tf.keras.Input(shape=(128, 128, input_dim), batch_size=1)
    _ = network(inputs)

  def test_serialize_deserialize(self):
    # Create a network object that sets all of its config options.
    kwargs = dict(
        model_id=50,
        output_stride=8,
        use_sync_bn=False,
        activation='relu',
        norm_momentum=0.99,
        norm_epsilon=0.001,
        kernel_initializer='VarianceScaling',
        kernel_regularizer=None,
        bias_regularizer=None,
    )
    network = resnet_deeplab.DilatedResNet(**kwargs)

    expected_config = dict(kwargs)
    self.assertEqual(network.get_config(), expected_config)

    # Create another network object from the first object's config.
    new_network = resnet_deeplab.DilatedResNet.from_config(
        network.get_config())

    # Validate that the config can be forced to JSON.
    _ = new_network.to_json()

    # If the serialization was successful, the new config should match the old.
    self.assertAllEqual(network.get_config(), new_network.get_config())


if __name__ == '__main__':
  tf.test.main()
official/vision/beta/modeling/layers/nn_blocks.py  (view file @ 0168e493)

@@ -214,6 +214,7 @@ class BottleneckBlock(tf.keras.layers.Layer):
   def __init__(self,
                filters,
                strides,
+               dilation_rate=1,
                use_projection=False,
                stochastic_depth_drop_rate=None,
                kernel_initializer='VarianceScaling',
@@ -231,6 +232,7 @@ class BottleneckBlock(tf.keras.layers.Layer):
         the third and final convolution will use 4 times as many filters.
       strides: `int` block stride. If greater than 1, this block will ultimately
         downsample the input.
+      dilation_rate: `int` dilation_rate of convolutions. Default to 1.
       use_projection: `bool` for whether this block should use a projection
         shortcut (versus the default identity shortcut). This is usually `True`
         for the first block of a block group, which may change the number of
@@ -253,6 +255,7 @@ class BottleneckBlock(tf.keras.layers.Layer):
     self._filters = filters
     self._strides = strides
+    self._dilation_rate = dilation_rate
     self._use_projection = use_projection
     self._use_sync_bn = use_sync_bn
     self._activation = activation
@@ -304,6 +307,7 @@ class BottleneckBlock(tf.keras.layers.Layer):
         filters=self._filters,
         kernel_size=3,
         strides=self._strides,
+        dilation_rate=self._dilation_rate,
         padding='same',
         use_bias=False,
         kernel_initializer=self._kernel_initializer,
@@ -339,6 +343,7 @@ class BottleneckBlock(tf.keras.layers.Layer):
     config = {
         'filters': self._filters,
         'strides': self._strides,
+        'dilation_rate': self._dilation_rate,
         'use_projection': self._use_projection,
         'stochastic_depth_drop_rate': self._stochastic_depth_drop_rate,
         'kernel_initializer': self._kernel_initializer,
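Not part of the diff: the new `dilation_rate` argument is forwarded to the block's 3x3 convolution, so a dilated, stride-1 bottleneck block can be exercised roughly like this (tensor sizes are illustrative).

# Illustrative sketch, not part of this commit.
import tensorflow as tf
from official.vision.beta.modeling.layers import nn_blocks

block = nn_blocks.BottleneckBlock(
    filters=512, strides=1, dilation_rate=2, use_projection=True)

features = tf.random.normal([1, 32, 32, 1024])
outputs = block(features)
print(outputs.shape)  # (1, 32, 32, 2048): spatial size kept, channels 4x filters.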