Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
bc71d8e9
"git@developer.sourcefind.cn:OpenDAS/fastmoe.git" did not exist on "f9ce8e09869d87ba032404a03149e0af0bbd3f27"
Commit
bc71d8e9
authored
Jul 09, 2021
by
Dan Kondratyuk
Committed by
A. Unique TensorFlower
Jul 09, 2021
Browse files
Internal change
PiperOrigin-RevId: 383861528
parent
cb9aeaf4
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
128 additions
and
5 deletions
+128
-5
official/vision/beta/projects/movinet/configs/movinet.py
official/vision/beta/projects/movinet/configs/movinet.py
+5
-0
official/vision/beta/projects/movinet/export_saved_model.py
official/vision/beta/projects/movinet/export_saved_model.py
+7
-0
official/vision/beta/projects/movinet/modeling/movinet.py
official/vision/beta/projects/movinet/modeling/movinet.py
+12
-2
official/vision/beta/projects/movinet/modeling/movinet_layers.py
...l/vision/beta/projects/movinet/modeling/movinet_layers.py
+36
-3
official/vision/beta/projects/movinet/modeling/movinet_layers_test.py
...ion/beta/projects/movinet/modeling/movinet_layers_test.py
+37
-0
official/vision/beta/projects/movinet/modeling/movinet_model_test.py
...sion/beta/projects/movinet/modeling/movinet_model_test.py
+31
-0
No files found.
official/vision/beta/projects/movinet/configs/movinet.py
View file @
bc71d8e9
...
@@ -44,6 +44,11 @@ class Movinet(hyperparams.Config):
...
@@ -44,6 +44,11 @@ class Movinet(hyperparams.Config):
# 2plus1d: (2+1)D convolution with Conv2D (2D reshaping)
# 2plus1d: (2+1)D convolution with Conv2D (2D reshaping)
# 3d_2plus1d: (2+1)D convolution with Conv3D (no 2D reshaping)
# 3d_2plus1d: (2+1)D convolution with Conv3D (no 2D reshaping)
conv_type
:
str
=
'3d'
conv_type
:
str
=
'3d'
# Choose from ['3d', '2d', '2plus3d']
# 3d: default 3D global average pooling.
# 2d: 2D global average pooling.
# 2plus3d: concatenation of 2D and 3D global average pooling.
se_type
:
str
=
'3d'
activation
:
str
=
'swish'
activation
:
str
=
'swish'
gating_activation
:
str
=
'sigmoid'
gating_activation
:
str
=
'sigmoid'
stochastic_depth_drop_rate
:
float
=
0.2
stochastic_depth_drop_rate
:
float
=
0.2
...
...
official/vision/beta/projects/movinet/export_saved_model.py
View file @
bc71d8e9
...
@@ -53,6 +53,12 @@ flags.DEFINE_string(
...
@@ -53,6 +53,12 @@ flags.DEFINE_string(
'3x3 followed by 5x1 conv). 3d_2plus1d uses (2+1)D convolution with '
'3x3 followed by 5x1 conv). 3d_2plus1d uses (2+1)D convolution with '
'Conv3D and no 2D reshaping (e.g., a 5x3x3 kernel becomes 1x3x3 '
'Conv3D and no 2D reshaping (e.g., a 5x3x3 kernel becomes 1x3x3 '
'followed by 5x1x1 conv).'
)
'followed by 5x1x1 conv).'
)
flags
.
DEFINE_string
(
'se_type'
,
'3d'
,
'3d, 2d, or 2plus3d. 3d uses the default 3D spatiotemporal global average'
'pooling for squeeze excitation. 2d uses 2D spatial global average pooling '
'on each frame. 2plus3d concatenates both 3D and 2D global average '
'pooling.'
)
flags
.
DEFINE_string
(
flags
.
DEFINE_string
(
'activation'
,
'swish'
,
'activation'
,
'swish'
,
'The main activation to use across layers.'
)
'The main activation to use across layers.'
)
...
@@ -102,6 +108,7 @@ def main(_) -> None:
...
@@ -102,6 +108,7 @@ def main(_) -> None:
input_specs
=
input_specs
,
input_specs
=
input_specs
,
activation
=
FLAGS
.
activation
,
activation
=
FLAGS
.
activation
,
gating_activation
=
FLAGS
.
gating_activation
,
gating_activation
=
FLAGS
.
gating_activation
,
se_type
=
FLAGS
.
se_type
,
use_positional_encoding
=
FLAGS
.
use_positional_encoding
)
use_positional_encoding
=
FLAGS
.
use_positional_encoding
)
model
=
movinet_model
.
MovinetClassifier
(
model
=
movinet_model
.
MovinetClassifier
(
backbone
,
backbone
,
...
...
official/vision/beta/projects/movinet/modeling/movinet.py
View file @
bc71d8e9
...
@@ -307,6 +307,7 @@ class Movinet(tf.keras.Model):
...
@@ -307,6 +307,7 @@ class Movinet(tf.keras.Model):
causal
:
bool
=
False
,
causal
:
bool
=
False
,
use_positional_encoding
:
bool
=
False
,
use_positional_encoding
:
bool
=
False
,
conv_type
:
str
=
'3d'
,
conv_type
:
str
=
'3d'
,
se_type
:
str
=
'3d'
,
input_specs
:
Optional
[
tf
.
keras
.
layers
.
InputSpec
]
=
None
,
input_specs
:
Optional
[
tf
.
keras
.
layers
.
InputSpec
]
=
None
,
activation
:
str
=
'swish'
,
activation
:
str
=
'swish'
,
gating_activation
:
str
=
'sigmoid'
,
gating_activation
:
str
=
'sigmoid'
,
...
@@ -333,6 +334,10 @@ class Movinet(tf.keras.Model):
...
@@ -333,6 +334,10 @@ class Movinet(tf.keras.Model):
3x3 followed by 5x1 conv). '3d_2plus1d' uses (2+1)D convolution with
3x3 followed by 5x1 conv). '3d_2plus1d' uses (2+1)D convolution with
Conv3D and no 2D reshaping (e.g., a 5x3x3 kernel becomes 1x3x3 followed
Conv3D and no 2D reshaping (e.g., a 5x3x3 kernel becomes 1x3x3 followed
by 5x1x1 conv).
by 5x1x1 conv).
se_type: '3d', '2d', or '2plus3d'. '3d' uses the default 3D
spatiotemporal global average pooling for squeeze excitation. '2d'
uses 2D spatial global average pooling on each frame. '2plus3d'
concatenates both 3D and 2D global average pooling.
input_specs: the model input spec to use.
input_specs: the model input spec to use.
activation: name of the main activation function.
activation: name of the main activation function.
gating_activation: gating activation to use in squeeze excitation layers.
gating_activation: gating activation to use in squeeze excitation layers.
...
@@ -356,12 +361,15 @@ class Movinet(tf.keras.Model):
...
@@ -356,12 +361,15 @@ class Movinet(tf.keras.Model):
if
conv_type
not
in
(
'3d'
,
'2plus1d'
,
'3d_2plus1d'
):
if
conv_type
not
in
(
'3d'
,
'2plus1d'
,
'3d_2plus1d'
):
raise
ValueError
(
'Unknown conv type: {}'
.
format
(
conv_type
))
raise
ValueError
(
'Unknown conv type: {}'
.
format
(
conv_type
))
if
se_type
not
in
(
'3d'
,
'2d'
,
'2plus3d'
):
raise
ValueError
(
'Unknown squeeze excitation type: {}'
.
format
(
se_type
))
self
.
_model_id
=
model_id
self
.
_model_id
=
model_id
self
.
_block_specs
=
block_specs
self
.
_block_specs
=
block_specs
self
.
_causal
=
causal
self
.
_causal
=
causal
self
.
_use_positional_encoding
=
use_positional_encoding
self
.
_use_positional_encoding
=
use_positional_encoding
self
.
_conv_type
=
conv_type
self
.
_conv_type
=
conv_type
self
.
_se_type
=
se_type
self
.
_input_specs
=
input_specs
self
.
_input_specs
=
input_specs
self
.
_use_sync_bn
=
use_sync_bn
self
.
_use_sync_bn
=
use_sync_bn
self
.
_activation
=
activation
self
.
_activation
=
activation
...
@@ -481,8 +489,9 @@ class Movinet(tf.keras.Model):
...
@@ -481,8 +489,9 @@ class Movinet(tf.keras.Model):
gating_activation
=
self
.
_gating_activation
,
gating_activation
=
self
.
_gating_activation
,
stochastic_depth_drop_rate
=
stochastic_depth_drop_rate
,
stochastic_depth_drop_rate
=
stochastic_depth_drop_rate
,
conv_type
=
self
.
_conv_type
,
conv_type
=
self
.
_conv_type
,
use_positional_encoding
=
self
.
_use_positional_encoding
and
se_type
=
self
.
_se_type
,
self
.
_causal
,
use_positional_encoding
=
self
.
_use_positional_encoding
and
self
.
_causal
,
kernel_initializer
=
self
.
_kernel_initializer
,
kernel_initializer
=
self
.
_kernel_initializer
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
batch_norm_layer
=
self
.
_norm
,
batch_norm_layer
=
self
.
_norm
,
...
@@ -695,6 +704,7 @@ def build_movinet(
...
@@ -695,6 +704,7 @@ def build_movinet(
causal
=
backbone_cfg
.
causal
,
causal
=
backbone_cfg
.
causal
,
use_positional_encoding
=
backbone_cfg
.
use_positional_encoding
,
use_positional_encoding
=
backbone_cfg
.
use_positional_encoding
,
conv_type
=
backbone_cfg
.
conv_type
,
conv_type
=
backbone_cfg
.
conv_type
,
se_type
=
backbone_cfg
.
se_type
,
input_specs
=
input_specs
,
input_specs
=
input_specs
,
activation
=
backbone_cfg
.
activation
,
activation
=
backbone_cfg
.
activation
,
gating_activation
=
backbone_cfg
.
gating_activation
,
gating_activation
=
backbone_cfg
.
gating_activation
,
...
...
official/vision/beta/projects/movinet/modeling/movinet_layers.py
View file @
bc71d8e9
...
@@ -669,6 +669,7 @@ class StreamSqueezeExcitation(tf.keras.layers.Layer):
...
@@ -669,6 +669,7 @@ class StreamSqueezeExcitation(tf.keras.layers.Layer):
def
__init__
(
def
__init__
(
self
,
self
,
hidden_filters
:
int
,
hidden_filters
:
int
,
se_type
:
str
=
'3d'
,
activation
:
nn_layers
.
Activation
=
'swish'
,
activation
:
nn_layers
.
Activation
=
'swish'
,
gating_activation
:
nn_layers
.
Activation
=
'sigmoid'
,
gating_activation
:
nn_layers
.
Activation
=
'sigmoid'
,
causal
:
bool
=
False
,
causal
:
bool
=
False
,
...
@@ -683,6 +684,10 @@ class StreamSqueezeExcitation(tf.keras.layers.Layer):
...
@@ -683,6 +684,10 @@ class StreamSqueezeExcitation(tf.keras.layers.Layer):
Args:
Args:
hidden_filters: The hidden filters of squeeze excite.
hidden_filters: The hidden filters of squeeze excite.
se_type: '3d', '2d', or '2plus3d'. '3d' uses the default 3D
spatiotemporal global average pooling for squeeze excitation. '2d'
uses 2D spatial global average pooling on each frame. '2plus3d'
concatenates both 3D and 2D global average pooling.
activation: name of the activation function.
activation: name of the activation function.
gating_activation: name of the activation function for gating.
gating_activation: name of the activation function for gating.
causal: if True, use causal mode in the global average pool.
causal: if True, use causal mode in the global average pool.
...
@@ -700,6 +705,7 @@ class StreamSqueezeExcitation(tf.keras.layers.Layer):
...
@@ -700,6 +705,7 @@ class StreamSqueezeExcitation(tf.keras.layers.Layer):
super
(
StreamSqueezeExcitation
,
self
).
__init__
(
**
kwargs
)
super
(
StreamSqueezeExcitation
,
self
).
__init__
(
**
kwargs
)
self
.
_hidden_filters
=
hidden_filters
self
.
_hidden_filters
=
hidden_filters
self
.
_se_type
=
se_type
self
.
_activation
=
activation
self
.
_activation
=
activation
self
.
_gating_activation
=
gating_activation
self
.
_gating_activation
=
gating_activation
self
.
_causal
=
causal
self
.
_causal
=
causal
...
@@ -709,8 +715,9 @@ class StreamSqueezeExcitation(tf.keras.layers.Layer):
...
@@ -709,8 +715,9 @@ class StreamSqueezeExcitation(tf.keras.layers.Layer):
self
.
_use_positional_encoding
=
use_positional_encoding
self
.
_use_positional_encoding
=
use_positional_encoding
self
.
_state_prefix
=
state_prefix
self
.
_state_prefix
=
state_prefix
self
.
_pool
=
nn_layers
.
GlobalAveragePool3D
(
self
.
_
spatiotemporal_
pool
=
nn_layers
.
GlobalAveragePool3D
(
keepdims
=
True
,
causal
=
causal
,
state_prefix
=
state_prefix
)
keepdims
=
True
,
causal
=
causal
,
state_prefix
=
state_prefix
)
self
.
_spatial_pool
=
nn_layers
.
SpatialAveragePool3D
(
keepdims
=
True
)
self
.
_pos_encoding
=
None
self
.
_pos_encoding
=
None
if
use_positional_encoding
:
if
use_positional_encoding
:
...
@@ -721,6 +728,7 @@ class StreamSqueezeExcitation(tf.keras.layers.Layer):
...
@@ -721,6 +728,7 @@ class StreamSqueezeExcitation(tf.keras.layers.Layer):
"""Returns a dictionary containing the config used for initialization."""
"""Returns a dictionary containing the config used for initialization."""
config
=
{
config
=
{
'hidden_filters'
:
self
.
_hidden_filters
,
'hidden_filters'
:
self
.
_hidden_filters
,
'se_type'
:
self
.
_se_type
,
'activation'
:
self
.
_activation
,
'activation'
:
self
.
_activation
,
'gating_activation'
:
self
.
_gating_activation
,
'gating_activation'
:
self
.
_gating_activation
,
'causal'
:
self
.
_causal
,
'causal'
:
self
.
_causal
,
...
@@ -777,13 +785,28 @@ class StreamSqueezeExcitation(tf.keras.layers.Layer):
...
@@ -777,13 +785,28 @@ class StreamSqueezeExcitation(tf.keras.layers.Layer):
"""
"""
states
=
dict
(
states
)
if
states
is
not
None
else
{}
states
=
dict
(
states
)
if
states
is
not
None
else
{}
x
,
states
=
self
.
_pool
(
inputs
,
states
=
states
)
if
self
.
_se_type
==
'3d'
:
x
,
states
=
self
.
_spatiotemporal_pool
(
inputs
,
states
=
states
)
elif
self
.
_se_type
==
'2d'
:
x
=
self
.
_spatial_pool
(
inputs
)
elif
self
.
_se_type
==
'2plus3d'
:
x_space
=
self
.
_spatial_pool
(
inputs
)
x
,
states
=
self
.
_spatiotemporal_pool
(
x_space
,
states
=
states
)
if
not
self
.
_causal
:
x
=
tf
.
tile
(
x
,
[
1
,
tf
.
shape
(
inputs
)[
1
],
1
,
1
,
1
])
x
=
tf
.
concat
([
x
,
x_space
],
axis
=-
1
)
else
:
raise
ValueError
(
'Unknown Squeeze Excitation type {}'
.
format
(
self
.
_se_type
))
if
self
.
_pos_encoding
is
not
None
:
if
self
.
_pos_encoding
is
not
None
:
x
,
states
=
self
.
_pos_encoding
(
x
,
states
=
states
)
x
,
states
=
self
.
_pos_encoding
(
x
,
states
=
states
)
x
=
self
.
_se_reduce
(
x
)
x
=
self
.
_se_reduce
(
x
)
x
=
self
.
_se_expand
(
x
)
x
=
self
.
_se_expand
(
x
)
return
x
*
inputs
,
states
return
x
*
inputs
,
states
...
@@ -1003,6 +1026,7 @@ class MovinetBlock(tf.keras.layers.Layer):
...
@@ -1003,6 +1026,7 @@ class MovinetBlock(tf.keras.layers.Layer):
se_ratio
:
float
=
0.25
,
se_ratio
:
float
=
0.25
,
stochastic_depth_drop_rate
:
float
=
0.
,
stochastic_depth_drop_rate
:
float
=
0.
,
conv_type
:
str
=
'3d'
,
conv_type
:
str
=
'3d'
,
se_type
:
str
=
'3d'
,
use_positional_encoding
:
bool
=
False
,
use_positional_encoding
:
bool
=
False
,
kernel_initializer
:
tf
.
keras
.
initializers
.
Initializer
=
'HeNormal'
,
kernel_initializer
:
tf
.
keras
.
initializers
.
Initializer
=
'HeNormal'
,
kernel_regularizer
:
Optional
[
tf
.
keras
.
regularizers
.
Regularizer
]
=
tf
.
keras
kernel_regularizer
:
Optional
[
tf
.
keras
.
regularizers
.
Regularizer
]
=
tf
.
keras
...
@@ -1029,6 +1053,10 @@ class MovinetBlock(tf.keras.layers.Layer):
...
@@ -1029,6 +1053,10 @@ class MovinetBlock(tf.keras.layers.Layer):
ops. '2plus1d' split any 3D ops into two sequential 2D ops with their
ops. '2plus1d' split any 3D ops into two sequential 2D ops with their
own batch norm and activation. '3d_2plus1d' is like '2plus1d', but
own batch norm and activation. '3d_2plus1d' is like '2plus1d', but
uses two sequential 3D ops instead.
uses two sequential 3D ops instead.
se_type: '3d', '2d', or '2plus3d'. '3d' uses the default 3D
spatiotemporal global average pooling for squeeze excitation. '2d'
uses 2D spatial global average pooling on each frame. '2plus3d'
concatenates both 3D and 2D global average pooling.
use_positional_encoding: add a positional encoding after the (cumulative)
use_positional_encoding: add a positional encoding after the (cumulative)
global average pooling layer in the squeeze excite layer.
global average pooling layer in the squeeze excite layer.
kernel_initializer: kernel initializer for the conv operations.
kernel_initializer: kernel initializer for the conv operations.
...
@@ -1044,8 +1072,10 @@ class MovinetBlock(tf.keras.layers.Layer):
...
@@ -1044,8 +1072,10 @@ class MovinetBlock(tf.keras.layers.Layer):
self
.
_kernel_size
=
normalize_tuple
(
kernel_size
,
3
,
'kernel_size'
)
self
.
_kernel_size
=
normalize_tuple
(
kernel_size
,
3
,
'kernel_size'
)
self
.
_strides
=
normalize_tuple
(
strides
,
3
,
'strides'
)
self
.
_strides
=
normalize_tuple
(
strides
,
3
,
'strides'
)
# Use a multiplier of 2 if concatenating multiple features
se_multiplier
=
2
if
se_type
==
'2plus3d'
else
1
se_hidden_filters
=
nn_layers
.
make_divisible
(
se_hidden_filters
=
nn_layers
.
make_divisible
(
se_ratio
*
expand_filters
,
divisor
=
8
)
se_ratio
*
expand_filters
*
se_multiplier
,
divisor
=
8
)
self
.
_out_filters
=
out_filters
self
.
_out_filters
=
out_filters
self
.
_expand_filters
=
expand_filters
self
.
_expand_filters
=
expand_filters
self
.
_kernel_size
=
kernel_size
self
.
_kernel_size
=
kernel_size
...
@@ -1056,6 +1086,7 @@ class MovinetBlock(tf.keras.layers.Layer):
...
@@ -1056,6 +1086,7 @@ class MovinetBlock(tf.keras.layers.Layer):
self
.
_downsample
=
any
(
s
>
1
for
s
in
self
.
_strides
)
self
.
_downsample
=
any
(
s
>
1
for
s
in
self
.
_strides
)
self
.
_stochastic_depth_drop_rate
=
stochastic_depth_drop_rate
self
.
_stochastic_depth_drop_rate
=
stochastic_depth_drop_rate
self
.
_conv_type
=
conv_type
self
.
_conv_type
=
conv_type
self
.
_se_type
=
se_type
self
.
_use_positional_encoding
=
use_positional_encoding
self
.
_use_positional_encoding
=
use_positional_encoding
self
.
_kernel_initializer
=
kernel_initializer
self
.
_kernel_initializer
=
kernel_initializer
self
.
_kernel_regularizer
=
kernel_regularizer
self
.
_kernel_regularizer
=
kernel_regularizer
...
@@ -1106,6 +1137,7 @@ class MovinetBlock(tf.keras.layers.Layer):
...
@@ -1106,6 +1137,7 @@ class MovinetBlock(tf.keras.layers.Layer):
name
=
'projection'
)
name
=
'projection'
)
self
.
_attention
=
StreamSqueezeExcitation
(
self
.
_attention
=
StreamSqueezeExcitation
(
se_hidden_filters
,
se_hidden_filters
,
se_type
=
se_type
,
activation
=
activation
,
activation
=
activation
,
gating_activation
=
gating_activation
,
gating_activation
=
gating_activation
,
causal
=
self
.
_causal
,
causal
=
self
.
_causal
,
...
@@ -1129,6 +1161,7 @@ class MovinetBlock(tf.keras.layers.Layer):
...
@@ -1129,6 +1161,7 @@ class MovinetBlock(tf.keras.layers.Layer):
'se_ratio'
:
self
.
_se_ratio
,
'se_ratio'
:
self
.
_se_ratio
,
'stochastic_depth_drop_rate'
:
self
.
_stochastic_depth_drop_rate
,
'stochastic_depth_drop_rate'
:
self
.
_stochastic_depth_drop_rate
,
'conv_type'
:
self
.
_conv_type
,
'conv_type'
:
self
.
_conv_type
,
'se_type'
:
self
.
_se_type
,
'use_positional_encoding'
:
self
.
_use_positional_encoding
,
'use_positional_encoding'
:
self
.
_use_positional_encoding
,
'kernel_initializer'
:
self
.
_kernel_initializer
,
'kernel_initializer'
:
self
.
_kernel_initializer
,
'kernel_regularizer'
:
self
.
_kernel_regularizer
,
'kernel_regularizer'
:
self
.
_kernel_regularizer
,
...
...
official/vision/beta/projects/movinet/modeling/movinet_layers_test.py
View file @
bc71d8e9
...
@@ -314,6 +314,43 @@ class MovinetLayersTest(parameterized.TestCase, tf.test.TestCase):
...
@@ -314,6 +314,43 @@ class MovinetLayersTest(parameterized.TestCase, tf.test.TestCase):
[[
4.
,
4.
,
4.
]]]]],
[[
4.
,
4.
,
4.
]]]]],
1e-5
,
1e-5
)
1e-5
,
1e-5
)
def
test_stream_squeeze_excitation_2plus3d
(
self
):
se
=
movinet_layers
.
StreamSqueezeExcitation
(
3
,
se_type
=
'2plus3d'
,
causal
=
True
,
activation
=
'hard_swish'
,
gating_activation
=
'hard_sigmoid'
,
kernel_initializer
=
'ones'
)
inputs
=
tf
.
range
(
4
,
dtype
=
tf
.
float32
)
+
1.
inputs
=
tf
.
reshape
(
inputs
,
[
1
,
4
,
1
,
1
,
1
])
inputs
=
tf
.
tile
(
inputs
,
[
1
,
1
,
2
,
1
,
3
])
expected
,
_
=
se
(
inputs
)
for
num_splits
in
[
1
,
2
,
4
]:
frames
=
tf
.
split
(
inputs
,
inputs
.
shape
[
1
]
//
num_splits
,
axis
=
1
)
states
=
{}
predicted
=
[]
for
frame
in
frames
:
x
,
states
=
se
(
frame
,
states
=
states
)
predicted
.
append
(
x
)
predicted
=
tf
.
concat
(
predicted
,
axis
=
1
)
self
.
assertEqual
(
predicted
.
shape
,
expected
.
shape
)
self
.
assertAllClose
(
predicted
,
expected
)
self
.
assertAllClose
(
predicted
,
[[[[[
1.
,
1.
,
1.
]],
[[
1.
,
1.
,
1.
]]],
[[[
2.
,
2.
,
2.
]],
[[
2.
,
2.
,
2.
]]],
[[[
3.
,
3.
,
3.
]],
[[
3.
,
3.
,
3.
]]],
[[[
4.
,
4.
,
4.
]],
[[
4.
,
4.
,
4.
]]]]])
def
test_stream_movinet_block
(
self
):
def
test_stream_movinet_block
(
self
):
block
=
movinet_layers
.
MovinetBlock
(
block
=
movinet_layers
.
MovinetBlock
(
out_filters
=
3
,
out_filters
=
3
,
...
...
official/vision/beta/projects/movinet/modeling/movinet_model_test.py
View file @
bc71d8e9
...
@@ -131,6 +131,37 @@ class MovinetModelTest(parameterized.TestCase, tf.test.TestCase):
...
@@ -131,6 +131,37 @@ class MovinetModelTest(parameterized.TestCase, tf.test.TestCase):
self
.
assertEqual
(
predicted
.
shape
,
expected
.
shape
)
self
.
assertEqual
(
predicted
.
shape
,
expected
.
shape
)
self
.
assertAllClose
(
predicted
,
expected
,
1e-5
,
1e-5
)
self
.
assertAllClose
(
predicted
,
expected
,
1e-5
,
1e-5
)
def
test_movinet_classifier_mobile
(
self
):
"""Test if the model can run with mobile parameters."""
tf
.
keras
.
backend
.
set_image_data_format
(
'channels_last'
)
backbone
=
movinet
.
Movinet
(
model_id
=
'a0'
,
causal
=
True
,
use_external_states
=
True
,
conv_type
=
'2plus1d'
,
se_type
=
'2plus3d'
,
activation
=
'hard_swish'
,
gating_activation
=
'hard_sigmoid'
)
model
=
movinet_model
.
MovinetClassifier
(
backbone
,
num_classes
=
600
,
output_states
=
True
)
inputs
=
tf
.
ones
([
1
,
8
,
172
,
172
,
3
])
init_states
=
model
.
init_states
(
tf
.
shape
(
inputs
))
expected
,
_
=
model
({
**
init_states
,
'image'
:
inputs
})
frames
=
tf
.
split
(
inputs
,
inputs
.
shape
[
1
],
axis
=
1
)
states
=
init_states
for
frame
in
frames
:
output
,
states
=
model
({
**
states
,
'image'
:
frame
})
predicted
=
output
self
.
assertEqual
(
predicted
.
shape
,
expected
.
shape
)
self
.
assertAllClose
(
predicted
,
expected
,
1e-5
,
1e-5
)
def
test_serialize_deserialize
(
self
):
def
test_serialize_deserialize
(
self
):
"""Validate the classification network can be serialized and deserialized."""
"""Validate the classification network can be serialized and deserialized."""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment