Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
9ae6996d
Commit
9ae6996d
authored
May 11, 2021
by
Dan Kondratyuk
Committed by
A. Unique TensorFlower
May 11, 2021
Browse files
Internal change
PiperOrigin-RevId: 373155894
parent
091da63d
Changes
28
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
1667 additions
and
0 deletions
+1667
-0
official/vision/beta/projects/movinet/modeling/movinet_layers_test.py
...ion/beta/projects/movinet/modeling/movinet_layers_test.py
+370
-0
official/vision/beta/projects/movinet/modeling/movinet_model.py
...al/vision/beta/projects/movinet/modeling/movinet_model.py
+166
-0
official/vision/beta/projects/movinet/modeling/movinet_model_test.py
...sion/beta/projects/movinet/modeling/movinet_model_test.py
+177
-0
official/vision/beta/projects/movinet/modeling/movinet_test.py
...ial/vision/beta/projects/movinet/modeling/movinet_test.py
+175
-0
official/vision/beta/projects/movinet/movinet_tutorial.ipynb
official/vision/beta/projects/movinet/movinet_tutorial.ipynb
+589
-0
official/vision/beta/projects/movinet/requirements.txt
official/vision/beta/projects/movinet/requirements.txt
+1
-0
official/vision/beta/projects/movinet/train.py
official/vision/beta/projects/movinet/train.py
+94
-0
official/vision/beta/projects/movinet/train_test.py
official/vision/beta/projects/movinet/train_test.py
+95
-0
No files found.
official/vision/beta/projects/movinet/modeling/movinet_layers_test.py
0 → 100644
View file @
9ae6996d
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Tests for movinet_layers.py."""
from
absl.testing
import
parameterized
import
tensorflow
as
tf
from
official.vision.beta.modeling.layers
import
nn_layers
from
official.vision.beta.projects.movinet.modeling
import
movinet_layers
class
MovinetLayersTest
(
parameterized
.
TestCase
,
tf
.
test
.
TestCase
):
def
test_squeeze3d
(
self
):
squeeze
=
movinet_layers
.
Squeeze3D
()
inputs
=
tf
.
ones
([
5
,
1
,
1
,
1
,
3
])
predicted
=
squeeze
(
inputs
)
expected
=
tf
.
ones
([
5
,
3
])
self
.
assertEqual
(
predicted
.
shape
,
expected
.
shape
)
self
.
assertAllEqual
(
predicted
,
expected
)
def
test_mobile_conv2d
(
self
):
conv2d
=
movinet_layers
.
MobileConv2D
(
filters
=
3
,
kernel_size
=
(
3
,
3
),
strides
=
(
1
,
1
),
padding
=
'same'
,
kernel_initializer
=
'ones'
,
use_bias
=
False
,
use_depthwise
=
False
,
use_temporal
=
False
,
use_buffered_input
=
True
,
)
inputs
=
tf
.
ones
([
1
,
2
,
2
,
2
,
3
])
predicted
=
conv2d
(
inputs
)
expected
=
tf
.
constant
(
[[[[[
12.
,
12.
,
12.
],
[
12.
,
12.
,
12.
]],
[[
12.
,
12.
,
12.
],
[
12.
,
12.
,
12.
]]],
[[[
12.
,
12.
,
12.
],
[
12.
,
12.
,
12.
]],
[[
12.
,
12.
,
12.
],
[
12.
,
12.
,
12.
]]]]])
self
.
assertEqual
(
predicted
.
shape
,
expected
.
shape
)
self
.
assertAllClose
(
predicted
,
expected
)
def
test_mobile_conv2d_temporal
(
self
):
conv2d
=
movinet_layers
.
MobileConv2D
(
filters
=
3
,
kernel_size
=
(
3
,
1
),
strides
=
(
1
,
1
),
padding
=
'causal'
,
kernel_initializer
=
'ones'
,
use_bias
=
False
,
use_depthwise
=
True
,
use_temporal
=
True
,
use_buffered_input
=
True
,
)
inputs
=
tf
.
ones
([
1
,
2
,
2
,
1
,
3
])
paddings
=
[[
0
,
0
],
[
2
,
0
],
[
0
,
0
],
[
0
,
0
],
[
0
,
0
]]
padded_inputs
=
tf
.
pad
(
inputs
,
paddings
)
predicted
=
conv2d
(
padded_inputs
)
expected
=
tf
.
constant
(
[[[[[
1.
,
1.
,
1.
]],
[[
1.
,
1.
,
1.
]]],
[[[
2.
,
2.
,
2.
]],
[[
2.
,
2.
,
2.
]]]]])
self
.
assertEqual
(
predicted
.
shape
,
expected
.
shape
)
self
.
assertAllClose
(
predicted
,
expected
)
def
test_stream_buffer
(
self
):
conv3d_stream
=
nn_layers
.
Conv3D
(
filters
=
3
,
kernel_size
=
(
3
,
3
,
3
),
strides
=
(
1
,
2
,
2
),
padding
=
'causal'
,
kernel_initializer
=
'ones'
,
use_bias
=
False
,
use_buffered_input
=
True
,
)
buffer
=
movinet_layers
.
StreamBuffer
(
buffer_size
=
2
)
conv3d
=
nn_layers
.
Conv3D
(
filters
=
3
,
kernel_size
=
(
3
,
3
,
3
),
strides
=
(
1
,
2
,
2
),
padding
=
'causal'
,
kernel_initializer
=
'ones'
,
use_bias
=
False
,
use_buffered_input
=
False
,
)
inputs
=
tf
.
ones
([
1
,
4
,
2
,
2
,
3
])
expected
=
conv3d
(
inputs
)
for
num_splits
in
[
1
,
2
,
4
]:
frames
=
tf
.
split
(
inputs
,
inputs
.
shape
[
1
]
//
num_splits
,
axis
=
1
)
states
=
{}
predicted
=
[]
for
frame
in
frames
:
x
,
states
=
buffer
(
frame
,
states
=
states
)
x
=
conv3d_stream
(
x
)
predicted
.
append
(
x
)
predicted
=
tf
.
concat
(
predicted
,
axis
=
1
)
self
.
assertEqual
(
predicted
.
shape
,
expected
.
shape
)
self
.
assertAllClose
(
predicted
,
expected
)
self
.
assertAllClose
(
predicted
,
[[[[[
12.
,
12.
,
12.
]]],
[[[
24.
,
24.
,
24.
]]],
[[[
36.
,
36.
,
36.
]]],
[[[
36.
,
36.
,
36.
]]]]])
def
test_stream_conv_block_2plus1d
(
self
):
conv_block
=
movinet_layers
.
ConvBlock
(
filters
=
3
,
kernel_size
=
(
3
,
3
,
3
),
strides
=
(
1
,
2
,
2
),
causal
=
True
,
kernel_initializer
=
'ones'
,
use_bias
=
False
,
activation
=
'relu'
,
conv_type
=
'2plus1d'
,
use_positional_encoding
=
True
,
)
stream_conv_block
=
movinet_layers
.
StreamConvBlock
(
filters
=
3
,
kernel_size
=
(
3
,
3
,
3
),
strides
=
(
1
,
2
,
2
),
causal
=
True
,
kernel_initializer
=
'ones'
,
use_bias
=
False
,
activation
=
'relu'
,
conv_type
=
'2plus1d'
,
use_positional_encoding
=
True
,
)
inputs
=
tf
.
ones
([
1
,
4
,
2
,
2
,
3
])
expected
=
conv_block
(
inputs
)
predicted_disabled
,
_
=
stream_conv_block
(
inputs
)
self
.
assertEqual
(
predicted_disabled
.
shape
,
expected
.
shape
)
self
.
assertAllClose
(
predicted_disabled
,
expected
)
for
num_splits
in
[
1
,
2
,
4
]:
frames
=
tf
.
split
(
inputs
,
inputs
.
shape
[
1
]
//
num_splits
,
axis
=
1
)
states
=
{}
predicted
=
[]
for
frame
in
frames
:
x
,
states
=
stream_conv_block
(
frame
,
states
=
states
)
predicted
.
append
(
x
)
predicted
=
tf
.
concat
(
predicted
,
axis
=
1
)
self
.
assertEqual
(
predicted
.
shape
,
expected
.
shape
)
self
.
assertAllClose
(
predicted
,
expected
)
self
.
assertAllClose
(
predicted
,
[[[[[
35.9640400
,
35.9640400
,
35.9640400
]]],
[[[
71.9280700
,
71.9280700
,
71.9280700
]]],
[[[
107.892105
,
107.892105
,
107.892105
]]],
[[[
107.892105
,
107.892105
,
107.892105
]]]]])
def
test_stream_conv_block_3d_2plus1d
(
self
):
conv_block
=
movinet_layers
.
ConvBlock
(
filters
=
3
,
kernel_size
=
(
3
,
3
,
3
),
strides
=
(
1
,
2
,
2
),
causal
=
True
,
kernel_initializer
=
'ones'
,
use_bias
=
False
,
activation
=
'relu'
,
conv_type
=
'3d_2plus1d'
,
use_positional_encoding
=
True
,
)
stream_conv_block
=
movinet_layers
.
StreamConvBlock
(
filters
=
3
,
kernel_size
=
(
3
,
3
,
3
),
strides
=
(
1
,
2
,
2
),
causal
=
True
,
kernel_initializer
=
'ones'
,
use_bias
=
False
,
activation
=
'relu'
,
conv_type
=
'3d_2plus1d'
,
use_positional_encoding
=
True
,
)
inputs
=
tf
.
ones
([
1
,
4
,
2
,
2
,
3
])
expected
=
conv_block
(
inputs
)
predicted_disabled
,
_
=
stream_conv_block
(
inputs
)
self
.
assertEqual
(
predicted_disabled
.
shape
,
expected
.
shape
)
self
.
assertAllClose
(
predicted_disabled
,
expected
)
for
num_splits
in
[
1
,
2
,
4
]:
frames
=
tf
.
split
(
inputs
,
inputs
.
shape
[
1
]
//
num_splits
,
axis
=
1
)
states
=
{}
predicted
=
[]
for
frame
in
frames
:
x
,
states
=
stream_conv_block
(
frame
,
states
=
states
)
predicted
.
append
(
x
)
predicted
=
tf
.
concat
(
predicted
,
axis
=
1
)
self
.
assertEqual
(
predicted
.
shape
,
expected
.
shape
)
self
.
assertAllClose
(
predicted
,
expected
)
self
.
assertAllClose
(
predicted
,
[[[[[
35.9640400
,
35.9640400
,
35.9640400
]]],
[[[
71.9280700
,
71.9280700
,
71.9280700
]]],
[[[
107.892105
,
107.892105
,
107.892105
]]],
[[[
107.892105
,
107.892105
,
107.892105
]]]]])
def
test_stream_conv_block
(
self
):
conv_block
=
movinet_layers
.
ConvBlock
(
filters
=
3
,
kernel_size
=
(
3
,
3
,
3
),
strides
=
(
1
,
2
,
2
),
causal
=
True
,
kernel_initializer
=
'ones'
,
use_bias
=
False
,
activation
=
'relu'
,
)
stream_conv_block
=
movinet_layers
.
StreamConvBlock
(
filters
=
3
,
kernel_size
=
(
3
,
3
,
3
),
strides
=
(
1
,
2
,
2
),
causal
=
True
,
kernel_initializer
=
'ones'
,
use_bias
=
False
,
activation
=
'relu'
,
)
inputs
=
tf
.
ones
([
1
,
4
,
2
,
2
,
3
])
expected
=
conv_block
(
inputs
)
predicted_disabled
,
_
=
stream_conv_block
(
inputs
)
self
.
assertEqual
(
predicted_disabled
.
shape
,
expected
.
shape
)
self
.
assertAllClose
(
predicted_disabled
,
expected
)
for
num_splits
in
[
1
,
2
,
4
]:
frames
=
tf
.
split
(
inputs
,
inputs
.
shape
[
1
]
//
num_splits
,
axis
=
1
)
states
=
{}
predicted
=
[]
for
frame
in
frames
:
x
,
states
=
stream_conv_block
(
frame
,
states
=
states
)
predicted
.
append
(
x
)
predicted
=
tf
.
concat
(
predicted
,
axis
=
1
)
self
.
assertEqual
(
predicted
.
shape
,
expected
.
shape
)
self
.
assertAllClose
(
predicted
,
expected
)
self
.
assertAllClose
(
predicted
,
[[[[[
11.994005
,
11.994005
,
11.994005
]]],
[[[
23.988010
,
23.988010
,
23.988010
]]],
[[[
35.982014
,
35.982014
,
35.982014
]]],
[[[
35.982014
,
35.982014
,
35.982014
]]]]])
def
test_stream_squeeze_excitation
(
self
):
se
=
movinet_layers
.
StreamSqueezeExcitation
(
3
,
causal
=
True
,
kernel_initializer
=
'ones'
)
inputs
=
tf
.
range
(
4
,
dtype
=
tf
.
float32
)
+
1.
inputs
=
tf
.
reshape
(
inputs
,
[
1
,
4
,
1
,
1
,
1
])
inputs
=
tf
.
tile
(
inputs
,
[
1
,
1
,
2
,
1
,
3
])
expected
,
_
=
se
(
inputs
)
for
num_splits
in
[
1
,
2
,
4
]:
frames
=
tf
.
split
(
inputs
,
inputs
.
shape
[
1
]
//
num_splits
,
axis
=
1
)
states
=
{}
predicted
=
[]
for
frame
in
frames
:
x
,
states
=
se
(
frame
,
states
=
states
)
predicted
.
append
(
x
)
predicted
=
tf
.
concat
(
predicted
,
axis
=
1
)
self
.
assertEqual
(
predicted
.
shape
,
expected
.
shape
)
self
.
assertAllClose
(
predicted
,
expected
,
1e-5
,
1e-5
)
self
.
assertAllClose
(
predicted
,
[[[[[
0.9998109
,
0.9998109
,
0.9998109
]],
[[
0.9998109
,
0.9998109
,
0.9998109
]]],
[[[
1.9999969
,
1.9999969
,
1.9999969
]],
[[
1.9999969
,
1.9999969
,
1.9999969
]]],
[[[
3.
,
3.
,
3.
]],
[[
3.
,
3.
,
3.
]]],
[[[
4.
,
4.
,
4.
]],
[[
4.
,
4.
,
4.
]]]]],
1e-5
,
1e-5
)
def
test_stream_movinet_block
(
self
):
block
=
movinet_layers
.
MovinetBlock
(
out_filters
=
3
,
expand_filters
=
6
,
kernel_size
=
(
3
,
3
,
3
),
strides
=
(
1
,
2
,
2
),
causal
=
True
,
)
inputs
=
tf
.
range
(
4
,
dtype
=
tf
.
float32
)
+
1.
inputs
=
tf
.
reshape
(
inputs
,
[
1
,
4
,
1
,
1
,
1
])
inputs
=
tf
.
tile
(
inputs
,
[
1
,
1
,
2
,
1
,
3
])
expected
,
_
=
block
(
inputs
)
for
num_splits
in
[
1
,
2
,
4
]:
frames
=
tf
.
split
(
inputs
,
inputs
.
shape
[
1
]
//
num_splits
,
axis
=
1
)
states
=
{}
predicted
=
[]
for
frame
in
frames
:
x
,
states
=
block
(
frame
,
states
=
states
)
predicted
.
append
(
x
)
predicted
=
tf
.
concat
(
predicted
,
axis
=
1
)
self
.
assertEqual
(
predicted
.
shape
,
expected
.
shape
)
self
.
assertAllClose
(
predicted
,
expected
)
def
test_stream_classifier_head
(
self
):
head
=
movinet_layers
.
Head
(
project_filters
=
5
)
classifier_head
=
movinet_layers
.
ClassifierHead
(
head_filters
=
10
,
num_classes
=
4
)
inputs
=
tf
.
range
(
4
,
dtype
=
tf
.
float32
)
+
1.
inputs
=
tf
.
reshape
(
inputs
,
[
1
,
4
,
1
,
1
,
1
])
inputs
=
tf
.
tile
(
inputs
,
[
1
,
1
,
2
,
1
,
3
])
x
,
_
=
head
(
inputs
)
expected
=
classifier_head
(
x
)
for
num_splits
in
[
1
,
2
,
4
]:
frames
=
tf
.
split
(
inputs
,
inputs
.
shape
[
1
]
//
num_splits
,
axis
=
1
)
states
=
{}
for
frame
in
frames
:
x
,
states
=
head
(
frame
,
states
=
states
)
predicted
=
classifier_head
(
x
)
self
.
assertEqual
(
predicted
.
shape
,
expected
.
shape
)
self
.
assertAllClose
(
predicted
,
expected
)
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
official/vision/beta/projects/movinet/modeling/movinet_model.py
0 → 100644
View file @
9ae6996d
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Build Movinet for video classification.
Reference: https://arxiv.org/pdf/2103.11511.pdf
"""
from
typing
import
Mapping
from
absl
import
logging
import
tensorflow
as
tf
from
official.vision.beta.modeling
import
backbones
from
official.vision.beta.modeling
import
factory_3d
as
model_factory
from
official.vision.beta.projects.movinet.configs
import
movinet
as
cfg
from
official.vision.beta.projects.movinet.modeling
import
movinet_layers
@
tf
.
keras
.
utils
.
register_keras_serializable
(
package
=
'Vision'
)
class
MovinetClassifier
(
tf
.
keras
.
Model
):
"""A video classification class builder."""
def
__init__
(
self
,
backbone
:
tf
.
keras
.
Model
,
num_classes
:
int
,
input_specs
:
Mapping
[
str
,
tf
.
keras
.
layers
.
InputSpec
]
=
None
,
dropout_rate
:
float
=
0.0
,
kernel_initializer
:
str
=
'HeNormal'
,
kernel_regularizer
:
tf
.
keras
.
regularizers
.
Regularizer
=
None
,
bias_regularizer
:
tf
.
keras
.
regularizers
.
Regularizer
=
None
,
output_states
:
bool
=
False
,
**
kwargs
):
"""Movinet initialization function.
Args:
backbone: A 3d backbone network.
num_classes: Number of classes in classification task.
input_specs: Specs of the input tensor.
dropout_rate: Rate for dropout regularization.
kernel_initializer: Kernel initializer for the final dense layer.
kernel_regularizer: Kernel regularizer.
bias_regularizer: Bias regularizer.
output_states: if True, output intermediate states that can be used to run
the model in streaming mode. Inputting the output states of the
previous input clip with the current input clip will utilize a stream
buffer for streaming video.
**kwargs: Keyword arguments to be passed.
"""
if
not
input_specs
:
input_specs
=
{
'image'
:
tf
.
keras
.
layers
.
InputSpec
(
shape
=
[
None
,
None
,
None
,
None
,
3
])
}
self
.
_num_classes
=
num_classes
self
.
_input_specs
=
input_specs
self
.
_dropout_rate
=
dropout_rate
self
.
_kernel_initializer
=
kernel_initializer
self
.
_kernel_regularizer
=
kernel_regularizer
self
.
_bias_regularizer
=
bias_regularizer
self
.
_output_states
=
output_states
# Keras model variable that excludes @property.setters from tracking
self
.
_self_setattr_tracking
=
False
inputs
=
{
name
:
tf
.
keras
.
Input
(
shape
=
state
.
shape
[
1
:],
name
=
f
'states/
{
name
}
'
)
for
name
,
state
in
input_specs
.
items
()
}
states
=
inputs
.
get
(
'states'
,
{})
endpoints
,
states
=
backbone
(
dict
(
image
=
inputs
[
'image'
],
states
=
states
))
x
=
endpoints
[
'head'
]
x
=
movinet_layers
.
ClassifierHead
(
head_filters
=
backbone
.
_head_filters
,
num_classes
=
num_classes
,
dropout_rate
=
dropout_rate
,
kernel_initializer
=
kernel_initializer
,
kernel_regularizer
=
kernel_regularizer
,
conv_type
=
backbone
.
_conv_type
)(
x
)
if
output_states
:
inputs
[
'states'
]
=
{
k
:
tf
.
keras
.
Input
(
shape
=
v
.
shape
[
1
:],
name
=
k
)
for
k
,
v
in
states
.
items
()
}
outputs
=
(
x
,
states
)
if
output_states
else
x
super
(
MovinetClassifier
,
self
).
__init__
(
inputs
=
inputs
,
outputs
=
outputs
,
**
kwargs
)
# Move backbone after super() call so Keras is happy
self
.
_backbone
=
backbone
@
property
def
checkpoint_items
(
self
):
"""Returns a dictionary of items to be additionally checkpointed."""
return
dict
(
backbone
=
self
.
backbone
)
@
property
def
backbone
(
self
):
return
self
.
_backbone
def
get_config
(
self
):
config
=
{
'backbone'
:
self
.
_backbone
,
'num_classes'
:
self
.
_num_classes
,
'input_specs'
:
self
.
_input_specs
,
'dropout_rate'
:
self
.
_dropout_rate
,
'kernel_initializer'
:
self
.
_kernel_initializer
,
'kernel_regularizer'
:
self
.
_kernel_regularizer
,
'bias_regularizer'
:
self
.
_bias_regularizer
,
'output_states'
:
self
.
_output_states
,
}
return
config
@
classmethod
def
from_config
(
cls
,
config
,
custom_objects
=
None
):
# Each InputSpec may need to be deserialized
# This handles the case where we want to load a saved_model loaded with
# `tf.keras.models.load_model`
if
config
[
'input_specs'
]:
for
name
in
config
[
'input_specs'
]:
if
isinstance
(
config
[
'input_specs'
][
name
],
dict
):
config
[
'input_specs'
][
name
]
=
tf
.
keras
.
layers
.
deserialize
(
config
[
'input_specs'
][
name
])
return
cls
(
**
config
)
@
model_factory
.
register_model_builder
(
'movinet'
)
def
build_movinet_model
(
input_specs
:
tf
.
keras
.
layers
.
InputSpec
,
model_config
:
cfg
.
MovinetModel
,
num_classes
:
int
,
l2_regularizer
:
tf
.
keras
.
regularizers
.
Regularizer
=
None
):
"""Builds movinet model."""
logging
.
info
(
'Building movinet model with num classes: %s'
,
num_classes
)
if
l2_regularizer
is
not
None
:
logging
.
info
(
'Building movinet model with regularizer: %s'
,
l2_regularizer
.
get_config
())
input_specs_dict
=
{
'image'
:
input_specs
}
backbone
=
backbones
.
factory
.
build_backbone
(
input_specs
=
input_specs
,
backbone_config
=
model_config
.
backbone
,
norm_activation_config
=
model_config
.
norm_activation
,
l2_regularizer
=
l2_regularizer
)
model
=
MovinetClassifier
(
backbone
,
num_classes
=
num_classes
,
kernel_regularizer
=
l2_regularizer
,
input_specs
=
input_specs_dict
,
dropout_rate
=
model_config
.
dropout_rate
)
return
model
official/vision/beta/projects/movinet/modeling/movinet_model_test.py
0 → 100644
View file @
9ae6996d
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Tests for movinet_model.py."""
# Import libraries
from
absl.testing
import
parameterized
import
numpy
as
np
import
tensorflow
as
tf
from
official.vision.beta.projects.movinet.modeling
import
movinet
from
official.vision.beta.projects.movinet.modeling
import
movinet_model
class
MovinetModelTest
(
parameterized
.
TestCase
,
tf
.
test
.
TestCase
):
@
parameterized
.
parameters
(
False
,
True
)
def
test_movinet_classifier_creation
(
self
,
is_training
):
"""Test for creation of a Movinet classifier."""
temporal_size
=
16
spatial_size
=
224
tf
.
keras
.
backend
.
set_image_data_format
(
'channels_last'
)
input_specs
=
tf
.
keras
.
layers
.
InputSpec
(
shape
=
[
None
,
temporal_size
,
spatial_size
,
spatial_size
,
3
])
backbone
=
movinet
.
Movinet
(
model_id
=
'a0'
,
input_specs
=
input_specs
)
num_classes
=
1000
model
=
movinet_model
.
MovinetClassifier
(
backbone
=
backbone
,
num_classes
=
num_classes
,
input_specs
=
{
'image'
:
input_specs
},
dropout_rate
=
0.2
)
inputs
=
np
.
random
.
rand
(
2
,
temporal_size
,
spatial_size
,
spatial_size
,
3
)
logits
=
model
(
inputs
,
training
=
is_training
)
self
.
assertAllEqual
([
2
,
num_classes
],
logits
.
shape
)
def
test_movinet_classifier_stream
(
self
):
tf
.
keras
.
backend
.
set_image_data_format
(
'channels_last'
)
model
=
movinet
.
Movinet
(
model_id
=
'a0'
,
causal
=
True
,
)
inputs
=
tf
.
ones
([
1
,
5
,
128
,
128
,
3
])
expected_endpoints
,
_
=
model
(
dict
(
image
=
inputs
,
states
=
{}))
frames
=
tf
.
split
(
inputs
,
inputs
.
shape
[
1
],
axis
=
1
)
output
,
states
=
None
,
{}
for
frame
in
frames
:
output
,
states
=
model
(
dict
(
image
=
frame
,
states
=
states
))
predicted_endpoints
=
output
predicted
=
predicted_endpoints
[
'head'
]
# The expected final output is simply the mean across frames
expected
=
expected_endpoints
[
'head'
]
expected
=
tf
.
reduce_mean
(
expected
,
1
,
keepdims
=
True
)
self
.
assertEqual
(
predicted
.
shape
,
expected
.
shape
)
self
.
assertAllClose
(
predicted
,
expected
,
1e-5
,
1e-5
)
def
test_serialize_deserialize
(
self
):
"""Validate the classification network can be serialized and deserialized."""
backbone
=
movinet
.
Movinet
(
model_id
=
'a0'
)
model
=
movinet_model
.
MovinetClassifier
(
backbone
=
backbone
,
num_classes
=
1000
)
config
=
model
.
get_config
()
new_model
=
movinet_model
.
MovinetClassifier
.
from_config
(
config
)
# Validate that the config can be forced to JSON.
new_model
.
to_json
()
# If the serialization was successful, the new config should match the old.
self
.
assertAllEqual
(
model
.
get_config
(),
new_model
.
get_config
())
def
test_saved_model_save_load
(
self
):
backbone
=
movinet
.
Movinet
(
'a0'
)
model
=
movinet_model
.
MovinetClassifier
(
backbone
,
num_classes
=
600
)
model
.
build
([
1
,
5
,
172
,
172
,
3
])
model
.
compile
(
metrics
=
[
'acc'
])
tf
.
keras
.
models
.
save_model
(
model
,
'/tmp/movinet/'
)
loaded_model
=
tf
.
keras
.
models
.
load_model
(
'/tmp/movinet/'
)
output
=
loaded_model
(
dict
(
image
=
tf
.
ones
([
1
,
1
,
1
,
1
,
3
])))
self
.
assertAllEqual
(
output
.
shape
,
[
1
,
600
])
@
parameterized
.
parameters
(
(
'a0'
,
3.126071
),
(
'a1'
,
4.717912
),
(
'a2'
,
5.280922
),
(
'a3'
,
7.443289
),
(
'a4'
,
11.422727
),
(
'a5'
,
18.763355
),
(
't0'
,
1.740502
),
)
def
test_movinet_models
(
self
,
model_id
,
expected_params_millions
):
"""Test creation of MoViNet family models with states."""
tf
.
keras
.
backend
.
set_image_data_format
(
'channels_last'
)
model
=
movinet_model
.
MovinetClassifier
(
backbone
=
movinet
.
Movinet
(
model_id
=
model_id
,
causal
=
True
),
num_classes
=
600
)
model
.
build
([
1
,
1
,
1
,
1
,
3
])
num_params_millions
=
model
.
count_params
()
/
1e6
self
.
assertEqual
(
num_params_millions
,
expected_params_millions
)
def
test_movinet_a0_2plus1d
(
self
):
"""Test creation of MoViNet with 2plus1d configuration."""
tf
.
keras
.
backend
.
set_image_data_format
(
'channels_last'
)
model_2plus1d
=
movinet_model
.
MovinetClassifier
(
backbone
=
movinet
.
Movinet
(
model_id
=
'a0'
,
conv_type
=
'2plus1d'
),
num_classes
=
600
)
model_2plus1d
.
build
([
1
,
1
,
1
,
1
,
3
])
model_3d_2plus1d
=
movinet_model
.
MovinetClassifier
(
backbone
=
movinet
.
Movinet
(
model_id
=
'a0'
,
conv_type
=
'3d_2plus1d'
),
num_classes
=
600
)
model_3d_2plus1d
.
build
([
1
,
1
,
1
,
1
,
3
])
# Ensure both models have the same weights
weights
=
[]
for
var_2plus1d
,
var_3d_2plus1d
in
zip
(
model_2plus1d
.
get_weights
(),
model_3d_2plus1d
.
get_weights
()):
if
var_2plus1d
.
shape
==
var_3d_2plus1d
.
shape
:
weights
.
append
(
var_3d_2plus1d
)
else
:
if
var_3d_2plus1d
.
shape
[
0
]
==
1
:
weight
=
var_3d_2plus1d
[
0
]
else
:
weight
=
var_3d_2plus1d
[:,
0
]
if
weight
.
shape
[
-
1
]
!=
var_2plus1d
.
shape
[
-
1
]:
# Transpose any depthwise kernels (conv3d --> depthwise_conv2d)
weight
=
tf
.
transpose
(
weight
,
perm
=
(
0
,
1
,
3
,
2
))
weights
.
append
(
weight
)
model_2plus1d
.
set_weights
(
weights
)
inputs
=
np
.
random
.
rand
(
2
,
8
,
172
,
172
,
3
)
logits_2plus1d
=
model_2plus1d
(
inputs
)
logits_3d_2plus1d
=
model_3d_2plus1d
(
inputs
)
# Ensure both models have the same output, since the weights are the same
self
.
assertAllEqual
(
logits_2plus1d
.
shape
,
logits_3d_2plus1d
.
shape
)
self
.
assertAllClose
(
logits_2plus1d
,
logits_3d_2plus1d
)
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
official/vision/beta/projects/movinet/modeling/movinet_test.py
0 → 100644
View file @
9ae6996d
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Tests for movinet.py."""
# Import libraries
from
absl.testing
import
parameterized
import
tensorflow
as
tf
from
official.vision.beta.projects.movinet.modeling
import
movinet
class
MoViNetTest
(
parameterized
.
TestCase
,
tf
.
test
.
TestCase
):
def
test_network_creation
(
self
):
"""Test creation of MoViNet family models."""
tf
.
keras
.
backend
.
set_image_data_format
(
'channels_last'
)
network
=
movinet
.
Movinet
(
model_id
=
'a0'
,
causal
=
True
,
)
inputs
=
tf
.
keras
.
Input
(
shape
=
(
8
,
128
,
128
,
3
),
batch_size
=
1
)
endpoints
,
states
=
network
(
inputs
)
self
.
assertAllEqual
(
endpoints
[
'stem'
].
shape
,
[
1
,
8
,
64
,
64
,
8
])
self
.
assertAllEqual
(
endpoints
[
'b0/l0'
].
shape
,
[
1
,
8
,
32
,
32
,
8
])
self
.
assertAllEqual
(
endpoints
[
'b1/l0'
].
shape
,
[
1
,
8
,
16
,
16
,
32
])
self
.
assertAllEqual
(
endpoints
[
'b2/l0'
].
shape
,
[
1
,
8
,
8
,
8
,
56
])
self
.
assertAllEqual
(
endpoints
[
'b3/l0'
].
shape
,
[
1
,
8
,
8
,
8
,
56
])
self
.
assertAllEqual
(
endpoints
[
'b4/l0'
].
shape
,
[
1
,
8
,
4
,
4
,
104
])
self
.
assertAllEqual
(
endpoints
[
'head'
].
shape
,
[
1
,
1
,
1
,
1
,
480
])
self
.
assertNotEmpty
(
states
)
def
test_network_with_states
(
self
):
"""Test creation of MoViNet family models with states."""
tf
.
keras
.
backend
.
set_image_data_format
(
'channels_last'
)
network
=
movinet
.
Movinet
(
model_id
=
'a0'
,
causal
=
True
,
)
inputs
=
tf
.
ones
([
1
,
8
,
128
,
128
,
3
])
_
,
states
=
network
(
inputs
)
endpoints
,
new_states
=
network
(
dict
(
image
=
inputs
,
states
=
states
))
self
.
assertAllEqual
(
endpoints
[
'stem'
].
shape
,
[
1
,
8
,
64
,
64
,
8
])
self
.
assertAllEqual
(
endpoints
[
'b0/l0'
].
shape
,
[
1
,
8
,
32
,
32
,
8
])
self
.
assertAllEqual
(
endpoints
[
'b1/l0'
].
shape
,
[
1
,
8
,
16
,
16
,
32
])
self
.
assertAllEqual
(
endpoints
[
'b2/l0'
].
shape
,
[
1
,
8
,
8
,
8
,
56
])
self
.
assertAllEqual
(
endpoints
[
'b3/l0'
].
shape
,
[
1
,
8
,
8
,
8
,
56
])
self
.
assertAllEqual
(
endpoints
[
'b4/l0'
].
shape
,
[
1
,
8
,
4
,
4
,
104
])
self
.
assertAllEqual
(
endpoints
[
'head'
].
shape
,
[
1
,
1
,
1
,
1
,
480
])
self
.
assertNotEmpty
(
states
)
self
.
assertNotEmpty
(
new_states
)
def
test_movinet_stream
(
self
):
tf
.
keras
.
backend
.
set_image_data_format
(
'channels_last'
)
model
=
movinet
.
Movinet
(
model_id
=
'a0'
,
causal
=
True
,
)
inputs
=
tf
.
ones
([
1
,
5
,
128
,
128
,
3
])
expected_endpoints
,
_
=
model
(
dict
(
image
=
inputs
,
states
=
{}))
frames
=
tf
.
split
(
inputs
,
inputs
.
shape
[
1
],
axis
=
1
)
output
,
states
=
None
,
{}
for
frame
in
frames
:
output
,
states
=
model
(
dict
(
image
=
frame
,
states
=
states
))
predicted_endpoints
=
output
predicted
=
predicted_endpoints
[
'head'
]
# The expected final output is simply the mean across frames
expected
=
expected_endpoints
[
'head'
]
expected
=
tf
.
reduce_mean
(
expected
,
1
,
keepdims
=
True
)
self
.
assertEqual
(
predicted
.
shape
,
expected
.
shape
)
self
.
assertAllClose
(
predicted
,
expected
,
1e-5
,
1e-5
)
def
test_movinet_2plus1d_stream
(
self
):
tf
.
keras
.
backend
.
set_image_data_format
(
'channels_last'
)
model
=
movinet
.
Movinet
(
model_id
=
'a0'
,
causal
=
True
,
conv_type
=
'2plus1d'
,
)
inputs
=
tf
.
ones
([
1
,
5
,
128
,
128
,
3
])
expected_endpoints
,
_
=
model
(
dict
(
image
=
inputs
,
states
=
{}))
frames
=
tf
.
split
(
inputs
,
inputs
.
shape
[
1
],
axis
=
1
)
output
,
states
=
None
,
{}
for
frame
in
frames
:
output
,
states
=
model
(
dict
(
image
=
frame
,
states
=
states
))
predicted_endpoints
=
output
predicted
=
predicted_endpoints
[
'head'
]
# The expected final output is simply the mean across frames
expected
=
expected_endpoints
[
'head'
]
expected
=
tf
.
reduce_mean
(
expected
,
1
,
keepdims
=
True
)
self
.
assertEqual
(
predicted
.
shape
,
expected
.
shape
)
self
.
assertAllClose
(
predicted
,
expected
,
1e-5
,
1e-5
)
def
test_movinet_3d_2plus1d_stream
(
self
):
tf
.
keras
.
backend
.
set_image_data_format
(
'channels_last'
)
model
=
movinet
.
Movinet
(
model_id
=
'a0'
,
causal
=
True
,
conv_type
=
'3d_2plus1d'
,
)
inputs
=
tf
.
ones
([
1
,
5
,
128
,
128
,
3
])
expected_endpoints
,
_
=
model
(
dict
(
image
=
inputs
,
states
=
{}))
frames
=
tf
.
split
(
inputs
,
inputs
.
shape
[
1
],
axis
=
1
)
output
,
states
=
None
,
{}
for
frame
in
frames
:
output
,
states
=
model
(
dict
(
image
=
frame
,
states
=
states
))
predicted_endpoints
=
output
predicted
=
predicted_endpoints
[
'head'
]
# The expected final output is simply the mean across frames
expected
=
expected_endpoints
[
'head'
]
expected
=
tf
.
reduce_mean
(
expected
,
1
,
keepdims
=
True
)
self
.
assertEqual
(
predicted
.
shape
,
expected
.
shape
)
self
.
assertAllClose
(
predicted
,
expected
,
1e-5
,
1e-5
)
def
test_serialize_deserialize
(
self
):
# Create a network object that sets all of its config options.
kwargs
=
dict
(
model_id
=
'a0'
,
causal
=
True
,
use_positional_encoding
=
True
,
)
network
=
movinet
.
Movinet
(
**
kwargs
)
# Create another network object from the first object's config.
new_network
=
movinet
.
Movinet
.
from_config
(
network
.
get_config
())
# Validate that the config can be forced to JSON.
_
=
new_network
.
to_json
()
# If the serialization was successful, the new config should match the old.
self
.
assertAllEqual
(
network
.
get_config
(),
new_network
.
get_config
())
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
official/vision/beta/projects/movinet/movinet_tutorial.ipynb
0 → 100644
View file @
9ae6996d
This diff is collapsed.
Click to expand it.
official/vision/beta/projects/movinet/requirements.txt
0 → 100644
View file @
9ae6996d
mediapy
official/vision/beta/projects/movinet/train.py
0 → 100644
View file @
9ae6996d
This diff is collapsed.
Click to expand it.
official/vision/beta/projects/movinet/train_test.py
0 → 100644
View file @
9ae6996d
This diff is collapsed.
Click to expand it.
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment