Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
4fa92552
Commit
4fa92552
authored
Jun 28, 2021
by
Dan Kondratyuk
Committed by
A. Unique TensorFlower
Jun 28, 2021
Browse files
Add activation function parameters to backbone.
PiperOrigin-RevId: 381908843
parent
f649db2c
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
44 additions
and
2 deletions
+44
-2
official/vision/beta/modeling/layers/nn_layers.py
official/vision/beta/modeling/layers/nn_layers.py
+17
-0
official/vision/beta/modeling/layers/nn_layers_test.py
official/vision/beta/modeling/layers/nn_layers_test.py
+5
-0
official/vision/beta/projects/movinet/configs/movinet.py
official/vision/beta/projects/movinet/configs/movinet.py
+2
-0
official/vision/beta/projects/movinet/export_saved_model.py
official/vision/beta/projects/movinet/export_saved_model.py
+8
-0
official/vision/beta/projects/movinet/modeling/movinet.py
official/vision/beta/projects/movinet/modeling/movinet.py
+7
-2
official/vision/beta/projects/movinet/modeling/movinet_layers.py
...l/vision/beta/projects/movinet/modeling/movinet_layers.py
+5
-0
No files found.
official/vision/beta/modeling/layers/nn_layers.py
View file @
4fa92552
...
...
@@ -68,6 +68,23 @@ def round_filters(filters: int,
return
int
(
new_filters
)
def
hard_swish
(
x
:
tf
.
Tensor
)
->
tf
.
Tensor
:
"""A Swish6/H-Swish activation function.
Reference: Section 5.2 of Howard et al. "Searching for MobileNet V3."
https://arxiv.org/pdf/1905.02244.pdf
Args:
x: the input tensor.
Returns:
The activation output.
"""
return
x
*
tf
.
nn
.
relu6
(
x
+
3.
)
*
(
1.
/
6.
)
tf
.
keras
.
utils
.
get_custom_objects
().
update
({
'hard_swish'
:
hard_swish
})
@
tf
.
keras
.
utils
.
register_keras_serializable
(
package
=
'Vision'
)
class
SqueezeExcitation
(
tf
.
keras
.
layers
.
Layer
):
"""Creates a squeeze and excitation layer."""
...
...
official/vision/beta/modeling/layers/nn_layers_test.py
View file @
4fa92552
...
...
@@ -24,6 +24,11 @@ from official.vision.beta.modeling.layers import nn_layers
class
NNLayersTest
(
parameterized
.
TestCase
,
tf
.
test
.
TestCase
):
def
test_hard_swish
(
self
):
activation
=
tf
.
keras
.
layers
.
Activation
(
'hard_swish'
)
output
=
activation
(
tf
.
constant
([
-
3
,
-
1.5
,
0
,
3
]))
self
.
assertAllEqual
(
output
,
[
0.
,
-
0.375
,
0.
,
3.
])
def
test_scale
(
self
):
scale
=
nn_layers
.
Scale
(
initializer
=
tf
.
keras
.
initializers
.
constant
(
10.
))
output
=
scale
(
3.
)
...
...
official/vision/beta/projects/movinet/configs/movinet.py
View file @
4fa92552
...
...
@@ -44,6 +44,8 @@ class Movinet(hyperparams.Config):
# 2plus1d: (2+1)D convolution with Conv2D (2D reshaping)
# 3d_2plus1d: (2+1)D convolution with Conv3D (no 2D reshaping)
conv_type
:
str
=
'3d'
activation
:
str
=
'swish'
gating_activation
:
str
=
'sigmoid'
stochastic_depth_drop_rate
:
float
=
0.2
use_external_states
:
bool
=
False
...
...
official/vision/beta/projects/movinet/export_saved_model.py
View file @
4fa92552
...
...
@@ -53,6 +53,12 @@ flags.DEFINE_string(
'3x3 followed by 5x1 conv). 3d_2plus1d uses (2+1)D convolution with '
'Conv3D and no 2D reshaping (e.g., a 5x3x3 kernel becomes 1x3x3 '
'followed by 5x1x1 conv).'
)
flags
.
DEFINE_string
(
'activation'
,
'swish'
,
'The main activation to use across layers.'
)
flags
.
DEFINE_string
(
'gating_activation'
,
'sigmoid'
,
'The gating activation to use in squeeze-excitation layers.'
)
flags
.
DEFINE_bool
(
'use_positional_encoding'
,
False
,
'Whether to use positional encoding (only applied when causal=True).'
)
...
...
@@ -94,6 +100,8 @@ def main(_) -> None:
conv_type
=
FLAGS
.
conv_type
,
use_external_states
=
FLAGS
.
causal
,
input_specs
=
input_specs
,
activation
=
FLAGS
.
activation
,
gating_activation
=
FLAGS
.
gating_activation
,
use_positional_encoding
=
FLAGS
.
use_positional_encoding
)
model
=
movinet_model
.
MovinetClassifier
(
backbone
,
...
...
official/vision/beta/projects/movinet/modeling/movinet.py
View file @
4fa92552
...
...
@@ -309,6 +309,7 @@ class Movinet(tf.keras.Model):
conv_type
:
str
=
'3d'
,
input_specs
:
Optional
[
tf
.
keras
.
layers
.
InputSpec
]
=
None
,
activation
:
str
=
'swish'
,
gating_activation
:
str
=
'sigmoid'
,
use_sync_bn
:
bool
=
True
,
norm_momentum
:
float
=
0.99
,
norm_epsilon
:
float
=
0.001
,
...
...
@@ -333,7 +334,8 @@ class Movinet(tf.keras.Model):
Conv3D and no 2D reshaping (e.g., a 5x3x3 kernel becomes 1x3x3 followed
by 5x1x1 conv).
input_specs: the model input spec to use.
activation: name of the activation function.
activation: name of the main activation function.
gating_activation: gating activation to use in squeeze excitation layers.
use_sync_bn: if True, use synchronized batch normalization.
norm_momentum: normalization momentum for the moving average.
norm_epsilon: small float added to variance to avoid dividing by
...
...
@@ -363,6 +365,7 @@ class Movinet(tf.keras.Model):
self
.
_input_specs
=
input_specs
self
.
_use_sync_bn
=
use_sync_bn
self
.
_activation
=
activation
self
.
_gating_activation
=
gating_activation
self
.
_norm_momentum
=
norm_momentum
self
.
_norm_epsilon
=
norm_epsilon
if
use_sync_bn
:
...
...
@@ -475,6 +478,7 @@ class Movinet(tf.keras.Model):
strides
=
strides
,
causal
=
self
.
_causal
,
activation
=
self
.
_activation
,
gating_activation
=
self
.
_gating_activation
,
stochastic_depth_drop_rate
=
stochastic_depth_drop_rate
,
conv_type
=
self
.
_conv_type
,
use_positional_encoding
=
self
.
_use_positional_encoding
and
...
...
@@ -692,7 +696,8 @@ def build_movinet(
use_positional_encoding
=
backbone_cfg
.
use_positional_encoding
,
conv_type
=
backbone_cfg
.
conv_type
,
input_specs
=
input_specs
,
activation
=
norm_activation_config
.
activation
,
activation
=
backbone_cfg
.
activation
,
gating_activation
=
backbone_cfg
.
gating_activation
,
use_sync_bn
=
norm_activation_config
.
use_sync_bn
,
norm_momentum
=
norm_activation_config
.
norm_momentum
,
norm_epsilon
=
norm_activation_config
.
norm_epsilon
,
...
...
official/vision/beta/projects/movinet/modeling/movinet_layers.py
View file @
4fa92552
...
...
@@ -999,6 +999,7 @@ class MovinetBlock(tf.keras.layers.Layer):
strides
:
Union
[
int
,
Sequence
[
int
]]
=
(
1
,
1
,
1
),
causal
:
bool
=
False
,
activation
:
nn_layers
.
Activation
=
'swish'
,
gating_activation
:
nn_layers
.
Activation
=
'sigmoid'
,
se_ratio
:
float
=
0.25
,
stochastic_depth_drop_rate
:
float
=
0.
,
conv_type
:
str
=
'3d'
,
...
...
@@ -1021,6 +1022,7 @@ class MovinetBlock(tf.keras.layers.Layer):
strides: strides of the main depthwise convolution.
causal: if True, run the temporal convolutions in causal mode.
activation: activation to use across all conv operations.
gating_activation: gating activation to use in squeeze excitation layers.
se_ratio: squeeze excite filters ratio.
stochastic_depth_drop_rate: optional drop rate for stochastic depth.
conv_type: '3d', '2plus1d', or '3d_2plus1d'. '3d' uses the default 3D
...
...
@@ -1049,6 +1051,7 @@ class MovinetBlock(tf.keras.layers.Layer):
self
.
_kernel_size
=
kernel_size
self
.
_causal
=
causal
self
.
_activation
=
activation
self
.
_gating_activation
=
gating_activation
self
.
_se_ratio
=
se_ratio
self
.
_downsample
=
any
(
s
>
1
for
s
in
self
.
_strides
)
self
.
_stochastic_depth_drop_rate
=
stochastic_depth_drop_rate
...
...
@@ -1104,6 +1107,7 @@ class MovinetBlock(tf.keras.layers.Layer):
self
.
_attention
=
StreamSqueezeExcitation
(
se_hidden_filters
,
activation
=
activation
,
gating_activation
=
gating_activation
,
causal
=
self
.
_causal
,
conv_type
=
conv_type
,
use_positional_encoding
=
use_positional_encoding
,
...
...
@@ -1121,6 +1125,7 @@ class MovinetBlock(tf.keras.layers.Layer):
'strides'
:
self
.
_strides
,
'causal'
:
self
.
_causal
,
'activation'
:
self
.
_activation
,
'gating_activation'
:
self
.
_gating_activation
,
'se_ratio'
:
self
.
_se_ratio
,
'stochastic_depth_drop_rate'
:
self
.
_stochastic_depth_drop_rate
,
'conv_type'
:
self
.
_conv_type
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment