Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
0509bc4e
"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "0cc90a8186359a197aa10d1dad361bc81b6ec2b2"
Commit
0509bc4e
authored
Dec 11, 2020
by
A. Unique TensorFlower
Browse files
Internal change
PiperOrigin-RevId: 347135923
parent
68c04c17
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
65 additions
and
5 deletions
+65
-5
official/vision/beta/configs/backbones.py
official/vision/beta/configs/backbones.py
+2
-0
official/vision/beta/configs/experiments/image_classification/imagenet_resnet101_deeplab_tpu.yaml
.../image_classification/imagenet_resnet101_deeplab_tpu.yaml
+7
-4
official/vision/beta/modeling/backbones/resnet_deeplab.py
official/vision/beta/modeling/backbones/resnet_deeplab.py
+23
-1
official/vision/beta/modeling/backbones/resnet_deeplab_test.py
...ial/vision/beta/modeling/backbones/resnet_deeplab_test.py
+33
-0
No files found.
official/vision/beta/configs/backbones.py
View file @
0509bc4e
...
...
@@ -40,6 +40,8 @@ class DilatedResNet(hyperparams.Config):
multigrid
:
Optional
[
List
[
int
]]
=
None
stem_type
:
str
=
'v0'
last_stage_repeats
:
int
=
1
se_ratio
:
float
=
0.0
stochastic_depth_drop_rate
:
float
=
0.0
@
dataclasses
.
dataclass
...
...
official/vision/beta/configs/experiments/image_classification/imagenet_resnet101_deeplab_tpu.yaml
View file @
0509bc4e
# Top1 accuracy 8
0.36%
# Top
-
1 accuracy 8
1.6% on ImageNet
runtime
:
distribution_strategy
:
'
tpu'
mixed_precision_dtype
:
'
bfloat16'
...
...
@@ -12,12 +12,14 @@ task:
model_id
:
101
output_stride
:
16
stem_type
:
'
v1'
se_ratio
:
0.25
stochastic_depth_drop_rate
:
0.2
multigrid
:
[
1
,
2
,
4
]
last_stage_repeats
:
1
norm_activation
:
activation
:
'
swish'
losses
:
l2_weight_decay
:
0.000
1
l2_weight_decay
:
0.000
04
one_hot
:
true
label_smoothing
:
0.1
train_data
:
...
...
@@ -25,6 +27,7 @@ task:
is_training
:
true
global_batch_size
:
4096
dtype
:
'
bfloat16'
aug_policy
:
'
randaug'
validation_data
:
input_path
:
'
imagenet-2012-tfrecord/valid*'
is_training
:
false
...
...
@@ -32,7 +35,7 @@ task:
dtype
:
'
bfloat16'
drop_remainder
:
false
trainer
:
train_steps
:
624
00
train_steps
:
1092
00
validation_steps
:
13
validation_interval
:
312
steps_per_loop
:
312
...
...
@@ -47,7 +50,7 @@ trainer:
type
:
'
cosine'
cosine
:
initial_learning_rate
:
1.6
decay_steps
:
624
00
decay_steps
:
1092
00
warmup
:
type
:
'
linear'
linear
:
...
...
official/vision/beta/modeling/backbones/resnet_deeplab.py
View file @
0509bc4e
...
...
@@ -19,6 +19,7 @@ import tensorflow as tf
from
official.modeling
import
tf_utils
from
official.vision.beta.modeling.backbones
import
factory
from
official.vision.beta.modeling.layers
import
nn_blocks
from
official.vision.beta.modeling.layers
import
nn_layers
layers
=
tf
.
keras
.
layers
...
...
@@ -57,6 +58,8 @@ class DilatedResNet(tf.keras.Model):
output_stride
,
input_specs
=
layers
.
InputSpec
(
shape
=
[
None
,
None
,
None
,
3
]),
stem_type
=
'v0'
,
se_ratio
=
None
,
init_stochastic_depth_rate
=
0.0
,
multigrid
=
None
,
last_stage_repeats
=
1
,
activation
=
'relu'
,
...
...
@@ -75,6 +78,8 @@ class DilatedResNet(tf.keras.Model):
input_specs: `tf.keras.layers.InputSpec` specs of the input tensor.
stem_type: `standard` or `deeplab`, deeplab replaces 7x7 conv by 3 3x3
convs.
se_ratio: `float` or None. Ratio of the Squeeze-and-Excitation layer.
init_stochastic_depth_rate: `float` initial stochastic depth rate.
multigrid: `Tuple` of the same length as the number of blocks in the last
resnet stage.
last_stage_repeats: `int`, how many times last stage is repeated.
...
...
@@ -105,6 +110,8 @@ class DilatedResNet(tf.keras.Model):
self
.
_kernel_regularizer
=
kernel_regularizer
self
.
_bias_regularizer
=
bias_regularizer
self
.
_stem_type
=
stem_type
self
.
_se_ratio
=
se_ratio
self
.
_init_stochastic_depth_rate
=
init_stochastic_depth_rate
if
tf
.
keras
.
backend
.
image_data_format
()
==
'channels_last'
:
bn_axis
=
-
1
...
...
@@ -193,6 +200,8 @@ class DilatedResNet(tf.keras.Model):
dilation_rate
=
1
,
block_fn
=
block_fn
,
block_repeats
=
spec
[
2
],
stochastic_depth_drop_rate
=
nn_layers
.
get_stochastic_depth_rate
(
self
.
_init_stochastic_depth_rate
,
i
+
2
,
4
+
last_stage_repeats
),
name
=
'block_group_l{}'
.
format
(
i
+
2
))
endpoints
[
str
(
i
+
2
)]
=
x
...
...
@@ -210,6 +219,8 @@ class DilatedResNet(tf.keras.Model):
dilation_rate
=
dilation_rate
,
block_fn
=
block_fn
,
block_repeats
=
spec
[
2
],
stochastic_depth_drop_rate
=
nn_layers
.
get_stochastic_depth_rate
(
self
.
_init_stochastic_depth_rate
,
i
+
2
,
4
+
last_stage_repeats
),
multigrid
=
multigrid
if
i
>=
3
else
None
,
name
=
'block_group_l{}'
.
format
(
i
+
2
))
dilation_rate
*=
2
...
...
@@ -228,6 +239,7 @@ class DilatedResNet(tf.keras.Model):
dilation_rate
,
block_fn
,
block_repeats
=
1
,
stochastic_depth_drop_rate
=
0.0
,
multigrid
=
None
,
name
=
'block_group'
):
"""Creates one group of blocks for the ResNet model.
...
...
@@ -242,6 +254,7 @@ class DilatedResNet(tf.keras.Model):
dilation_rate: `int`, diluted convolution rates.
block_fn: Either `nn_blocks.ResidualBlock` or `nn_blocks.BottleneckBlock`.
block_repeats: `int` number of blocks contained in the layer.
stochastic_depth_drop_rate: `float` drop rate of the current block group.
multigrid: List of ints or None, if specified, dilation rates for each
block is scaled up by its corresponding factor in the multigrid.
name: `str`name for the block.
...
...
@@ -261,6 +274,8 @@ class DilatedResNet(tf.keras.Model):
strides
=
strides
,
dilation_rate
=
dilation_rate
*
multigrid
[
0
],
use_projection
=
True
,
stochastic_depth_drop_rate
=
stochastic_depth_drop_rate
,
se_ratio
=
self
.
_se_ratio
,
kernel_initializer
=
self
.
_kernel_initializer
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
bias_regularizer
=
self
.
_bias_regularizer
,
...
...
@@ -275,6 +290,8 @@ class DilatedResNet(tf.keras.Model):
strides
=
1
,
dilation_rate
=
dilation_rate
*
multigrid
[
i
],
use_projection
=
False
,
stochastic_depth_drop_rate
=
stochastic_depth_drop_rate
,
se_ratio
=
self
.
_se_ratio
,
kernel_initializer
=
self
.
_kernel_initializer
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
bias_regularizer
=
self
.
_bias_regularizer
,
...
...
@@ -290,6 +307,9 @@ class DilatedResNet(tf.keras.Model):
config_dict
=
{
'model_id'
:
self
.
_model_id
,
'output_stride'
:
self
.
_output_stride
,
'stem_type'
:
self
.
_stem_type
,
'se_ratio'
:
self
.
_se_ratio
,
'init_stochastic_depth_rate'
:
self
.
_init_stochastic_depth_rate
,
'activation'
:
self
.
_activation
,
'use_sync_bn'
:
self
.
_use_sync_bn
,
'norm_momentum'
:
self
.
_norm_momentum
,
...
...
@@ -326,9 +346,11 @@ def build_dilated_resnet(
model_id
=
backbone_cfg
.
model_id
,
output_stride
=
backbone_cfg
.
output_stride
,
input_specs
=
input_specs
,
stem_type
=
backbone_cfg
.
stem_type
,
se_ratio
=
backbone_cfg
.
se_ratio
,
init_stochastic_depth_rate
=
backbone_cfg
.
stochastic_depth_drop_rate
,
multigrid
=
backbone_cfg
.
multigrid
,
last_stage_repeats
=
backbone_cfg
.
last_stage_repeats
,
stem_type
=
backbone_cfg
.
stem_type
,
activation
=
norm_activation_config
.
activation
,
use_sync_bn
=
norm_activation_config
.
use_sync_bn
,
norm_momentum
=
norm_activation_config
.
norm_momentum
,
...
...
official/vision/beta/modeling/backbones/resnet_deeplab_test.py
View file @
0509bc4e
...
...
@@ -48,6 +48,36 @@ class ResNetTest(parameterized.TestCase, tf.test.TestCase):
512
*
endpoint_filter_scale
],
endpoints
[
str
(
int
(
np
.
math
.
log2
(
output_stride
)))].
shape
.
as_list
())
@
parameterized
.
parameters
(
(
'v0'
,
None
,
0.0
),
(
'v1'
,
None
,
0.0
),
(
'v1'
,
0.25
,
0.0
),
(
'v1'
,
0.25
,
0.2
),
)
def
test_network_features
(
self
,
stem_type
,
se_ratio
,
init_stochastic_depth_rate
):
"""Test additional features of ResNet models."""
input_size
=
128
model_id
=
50
endpoint_filter_scale
=
4
output_stride
=
8
tf
.
keras
.
backend
.
set_image_data_format
(
'channels_last'
)
network
=
resnet_deeplab
.
DilatedResNet
(
model_id
=
model_id
,
output_stride
=
output_stride
,
stem_type
=
stem_type
,
se_ratio
=
se_ratio
,
init_stochastic_depth_rate
=
init_stochastic_depth_rate
)
inputs
=
tf
.
keras
.
Input
(
shape
=
(
input_size
,
input_size
,
3
),
batch_size
=
1
)
endpoints
=
network
(
inputs
)
print
(
endpoints
)
self
.
assertAllEqual
([
1
,
input_size
/
output_stride
,
input_size
/
output_stride
,
512
*
endpoint_filter_scale
],
endpoints
[
str
(
int
(
np
.
math
.
log2
(
output_stride
)))].
shape
.
as_list
())
@
combinations
.
generate
(
combinations
.
combine
(
strategy
=
[
...
...
@@ -84,6 +114,9 @@ class ResNetTest(parameterized.TestCase, tf.test.TestCase):
kwargs
=
dict
(
model_id
=
50
,
output_stride
=
8
,
stem_type
=
'v0'
,
se_ratio
=
0.25
,
init_stochastic_depth_rate
=
0.2
,
use_sync_bn
=
False
,
activation
=
'relu'
,
norm_momentum
=
0.99
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment