Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
7954d4b1
Commit
7954d4b1
authored
Mar 23, 2022
by
Liangzhe Yuan
Committed by
A. Unique TensorFlower
Mar 23, 2022
Browse files
Refactor resnet_3d.
PiperOrigin-RevId: 436815441
parent
53fb1c67
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
80 additions
and
57 deletions
+80
-57
official/vision/modeling/backbones/resnet_3d.py
official/vision/modeling/backbones/resnet_3d.py
+80
-57
No files found.
official/vision/modeling/backbones/resnet_3d.py
View file @
7954d4b1
...
@@ -153,19 +153,76 @@ class ResNet3D(tf.keras.Model):
...
@@ -153,19 +153,76 @@ class ResNet3D(tf.keras.Model):
self
.
_kernel_regularizer
=
kernel_regularizer
self
.
_kernel_regularizer
=
kernel_regularizer
self
.
_bias_regularizer
=
bias_regularizer
self
.
_bias_regularizer
=
bias_regularizer
if
tf
.
keras
.
backend
.
image_data_format
()
==
'channels_last'
:
if
tf
.
keras
.
backend
.
image_data_format
()
==
'channels_last'
:
bn_axis
=
-
1
self
.
_
bn_axis
=
-
1
else
:
else
:
bn_axis
=
1
self
.
_
bn_axis
=
1
# Build ResNet3D backbone.
# Build ResNet3D backbone.
inputs
=
tf
.
keras
.
Input
(
shape
=
input_specs
.
shape
[
1
:])
inputs
=
tf
.
keras
.
Input
(
shape
=
input_specs
.
shape
[
1
:])
endpoints
=
self
.
_build_model
(
inputs
)
self
.
_output_specs
=
{
l
:
endpoints
[
l
].
get_shape
()
for
l
in
endpoints
}
super
(
ResNet3D
,
self
).
__init__
(
inputs
=
inputs
,
outputs
=
endpoints
,
**
kwargs
)
def
_build_model
(
self
,
inputs
):
"""Builds model architecture.
Args:
inputs: the keras input spec.
Returns:
endpoints: A dictionary of backbone endpoint features.
"""
# Build stem.
x
=
self
.
_build_stem
(
inputs
,
stem_type
=
self
.
_stem_type
)
temporal_kernel_size
=
1
if
self
.
_stem_pool_temporal_stride
==
1
else
3
x
=
layers
.
MaxPool3D
(
pool_size
=
[
temporal_kernel_size
,
3
,
3
],
strides
=
[
self
.
_stem_pool_temporal_stride
,
2
,
2
],
padding
=
'same'
)(
x
)
# Build intermediate blocks and endpoints.
resnet_specs
=
RESNET_SPECS
[
self
.
_model_id
]
if
len
(
self
.
_temporal_strides
)
!=
len
(
resnet_specs
)
or
len
(
self
.
_temporal_kernel_sizes
)
!=
len
(
resnet_specs
):
raise
ValueError
(
'Number of blocks in temporal specs should equal to resnet_specs.'
)
endpoints
=
{}
for
i
,
resnet_spec
in
enumerate
(
resnet_specs
):
if
resnet_spec
[
0
]
==
'bottleneck3d'
:
block_fn
=
nn_blocks_3d
.
BottleneckBlock3D
else
:
raise
ValueError
(
'Block fn `{}` is not supported.'
.
format
(
resnet_spec
[
0
]))
use_self_gating
=
(
self
.
_use_self_gating
[
i
]
if
self
.
_use_self_gating
else
False
)
x
=
self
.
_block_group
(
inputs
=
x
,
filters
=
resnet_spec
[
1
],
temporal_kernel_sizes
=
self
.
_temporal_kernel_sizes
[
i
],
temporal_strides
=
self
.
_temporal_strides
[
i
],
spatial_strides
=
(
1
if
i
==
0
else
2
),
block_fn
=
block_fn
,
block_repeats
=
resnet_spec
[
2
],
stochastic_depth_drop_rate
=
nn_layers
.
get_stochastic_depth_rate
(
self
.
_init_stochastic_depth_rate
,
i
+
2
,
5
),
use_self_gating
=
use_self_gating
,
name
=
'block_group_l{}'
.
format
(
i
+
2
))
endpoints
[
str
(
i
+
2
)]
=
x
return
endpoints
def
_build_stem
(
self
,
inputs
,
stem_type
):
"""Builds stem layer."""
# Build stem.
# Build stem.
if
stem_type
==
'v0'
:
if
stem_type
==
'v0'
:
x
=
layers
.
Conv3D
(
x
=
layers
.
Conv3D
(
filters
=
64
,
filters
=
64
,
kernel_size
=
[
stem_conv_temporal_kernel_size
,
7
,
7
],
kernel_size
=
[
self
.
_
stem_conv_temporal_kernel_size
,
7
,
7
],
strides
=
[
stem_conv_temporal_stride
,
2
,
2
],
strides
=
[
self
.
_
stem_conv_temporal_stride
,
2
,
2
],
use_bias
=
False
,
use_bias
=
False
,
padding
=
'same'
,
padding
=
'same'
,
kernel_initializer
=
self
.
_kernel_initializer
,
kernel_initializer
=
self
.
_kernel_initializer
,
...
@@ -173,14 +230,15 @@ class ResNet3D(tf.keras.Model):
...
@@ -173,14 +230,15 @@ class ResNet3D(tf.keras.Model):
bias_regularizer
=
self
.
_bias_regularizer
)(
bias_regularizer
=
self
.
_bias_regularizer
)(
inputs
)
inputs
)
x
=
self
.
_norm
(
x
=
self
.
_norm
(
axis
=
bn_axis
,
momentum
=
norm_momentum
,
epsilon
=
norm_epsilon
)(
axis
=
self
.
_bn_axis
,
x
)
momentum
=
self
.
_norm_momentum
,
x
=
tf_utils
.
get_activation
(
activation
)(
x
)
epsilon
=
self
.
_norm_epsilon
)(
x
)
x
=
tf_utils
.
get_activation
(
self
.
_activation
)(
x
)
elif
stem_type
==
'v1'
:
elif
stem_type
==
'v1'
:
x
=
layers
.
Conv3D
(
x
=
layers
.
Conv3D
(
filters
=
32
,
filters
=
32
,
kernel_size
=
[
stem_conv_temporal_kernel_size
,
3
,
3
],
kernel_size
=
[
self
.
_
stem_conv_temporal_kernel_size
,
3
,
3
],
strides
=
[
stem_conv_temporal_stride
,
2
,
2
],
strides
=
[
self
.
_
stem_conv_temporal_stride
,
2
,
2
],
use_bias
=
False
,
use_bias
=
False
,
padding
=
'same'
,
padding
=
'same'
,
kernel_initializer
=
self
.
_kernel_initializer
,
kernel_initializer
=
self
.
_kernel_initializer
,
...
@@ -188,9 +246,10 @@ class ResNet3D(tf.keras.Model):
...
@@ -188,9 +246,10 @@ class ResNet3D(tf.keras.Model):
bias_regularizer
=
self
.
_bias_regularizer
)(
bias_regularizer
=
self
.
_bias_regularizer
)(
inputs
)
inputs
)
x
=
self
.
_norm
(
x
=
self
.
_norm
(
axis
=
bn_axis
,
momentum
=
norm_momentum
,
epsilon
=
norm_epsilon
)(
axis
=
self
.
_bn_axis
,
x
)
momentum
=
self
.
_norm_momentum
,
x
=
tf_utils
.
get_activation
(
activation
)(
x
)
epsilon
=
self
.
_norm_epsilon
)(
x
)
x
=
tf_utils
.
get_activation
(
self
.
_activation
)(
x
)
x
=
layers
.
Conv3D
(
x
=
layers
.
Conv3D
(
filters
=
32
,
filters
=
32
,
kernel_size
=
[
1
,
3
,
3
],
kernel_size
=
[
1
,
3
,
3
],
...
@@ -202,9 +261,10 @@ class ResNet3D(tf.keras.Model):
...
@@ -202,9 +261,10 @@ class ResNet3D(tf.keras.Model):
bias_regularizer
=
self
.
_bias_regularizer
)(
bias_regularizer
=
self
.
_bias_regularizer
)(
x
)
x
)
x
=
self
.
_norm
(
x
=
self
.
_norm
(
axis
=
bn_axis
,
momentum
=
norm_momentum
,
epsilon
=
norm_epsilon
)(
axis
=
self
.
_bn_axis
,
x
)
momentum
=
self
.
_norm_momentum
,
x
=
tf_utils
.
get_activation
(
activation
)(
x
)
epsilon
=
self
.
_norm_epsilon
)(
x
)
x
=
tf_utils
.
get_activation
(
self
.
_activation
)(
x
)
x
=
layers
.
Conv3D
(
x
=
layers
.
Conv3D
(
filters
=
64
,
filters
=
64
,
kernel_size
=
[
1
,
3
,
3
],
kernel_size
=
[
1
,
3
,
3
],
...
@@ -216,51 +276,14 @@ class ResNet3D(tf.keras.Model):
...
@@ -216,51 +276,14 @@ class ResNet3D(tf.keras.Model):
bias_regularizer
=
self
.
_bias_regularizer
)(
bias_regularizer
=
self
.
_bias_regularizer
)(
x
)
x
)
x
=
self
.
_norm
(
x
=
self
.
_norm
(
axis
=
bn_axis
,
momentum
=
norm_momentum
,
epsilon
=
norm_epsilon
)(
axis
=
self
.
_bn_axis
,
x
)
momentum
=
self
.
_norm_momentum
,
x
=
tf_utils
.
get_activation
(
activation
)(
x
)
epsilon
=
self
.
_norm_epsilon
)(
x
)
x
=
tf_utils
.
get_activation
(
self
.
_activation
)(
x
)
else
:
else
:
raise
ValueError
(
f
'Stem type
{
stem_type
}
not supported.'
)
raise
ValueError
(
f
'Stem type
{
stem_type
}
not supported.'
)
temporal_kernel_size
=
1
if
stem_pool_temporal_stride
==
1
else
3
return
x
x
=
layers
.
MaxPool3D
(
pool_size
=
[
temporal_kernel_size
,
3
,
3
],
strides
=
[
stem_pool_temporal_stride
,
2
,
2
],
padding
=
'same'
)(
x
)
# Build intermediate blocks and endpoints.
resnet_specs
=
RESNET_SPECS
[
model_id
]
if
len
(
temporal_strides
)
!=
len
(
resnet_specs
)
or
len
(
temporal_kernel_sizes
)
!=
len
(
resnet_specs
):
raise
ValueError
(
'Number of blocks in temporal specs should equal to resnet_specs.'
)
endpoints
=
{}
for
i
,
resnet_spec
in
enumerate
(
resnet_specs
):
if
resnet_spec
[
0
]
==
'bottleneck3d'
:
block_fn
=
nn_blocks_3d
.
BottleneckBlock3D
else
:
raise
ValueError
(
'Block fn `{}` is not supported.'
.
format
(
resnet_spec
[
0
]))
x
=
self
.
_block_group
(
inputs
=
x
,
filters
=
resnet_spec
[
1
],
temporal_kernel_sizes
=
temporal_kernel_sizes
[
i
],
temporal_strides
=
temporal_strides
[
i
],
spatial_strides
=
(
1
if
i
==
0
else
2
),
block_fn
=
block_fn
,
block_repeats
=
resnet_spec
[
2
],
stochastic_depth_drop_rate
=
nn_layers
.
get_stochastic_depth_rate
(
self
.
_init_stochastic_depth_rate
,
i
+
2
,
5
),
use_self_gating
=
use_self_gating
[
i
]
if
use_self_gating
else
False
,
name
=
'block_group_l{}'
.
format
(
i
+
2
))
endpoints
[
str
(
i
+
2
)]
=
x
self
.
_output_specs
=
{
l
:
endpoints
[
l
].
get_shape
()
for
l
in
endpoints
}
super
(
ResNet3D
,
self
).
__init__
(
inputs
=
inputs
,
outputs
=
endpoints
,
**
kwargs
)
def
_block_group
(
self
,
def
_block_group
(
self
,
inputs
:
tf
.
Tensor
,
inputs
:
tf
.
Tensor
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment