Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
98a558b7
Commit
98a558b7
authored
Jan 07, 2022
by
Liangzhe Yuan
Committed by
A. Unique TensorFlower
Jan 07, 2022
Browse files
#movinet Add se_type option in tools/convert_3d_2plus1d.py
PiperOrigin-RevId: 420368239
parent
d58be675
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
5 additions
and
0 deletions
+5
-0
official/vision/beta/projects/movinet/tools/convert_3d_2plus1d.py
.../vision/beta/projects/movinet/tools/convert_3d_2plus1d.py
+4
-0
official/vision/beta/projects/movinet/tools/convert_3d_2plus1d_test.py
...on/beta/projects/movinet/tools/convert_3d_2plus1d_test.py
+1
-0
No files found.
official/vision/beta/projects/movinet/tools/convert_3d_2plus1d.py
View file @
98a558b7
...
...
@@ -29,6 +29,8 @@ flags.DEFINE_string(
'Export path to save the saved_model file.'
)
flags
.
DEFINE_string
(
'model_id'
,
'a0'
,
'MoViNet model name.'
)
flags
.
DEFINE_string
(
'se_type'
,
'2plus3d'
,
'MoViNet model SE type.'
)
flags
.
DEFINE_bool
(
'causal'
,
True
,
'Run the model in causal mode.'
)
flags
.
DEFINE_bool
(
...
...
@@ -46,6 +48,7 @@ def main(_) -> None:
backbone_2plus1d
=
movinet
.
Movinet
(
model_id
=
FLAGS
.
model_id
,
causal
=
FLAGS
.
causal
,
se_type
=
FLAGS
.
se_type
,
conv_type
=
'2plus1d'
,
use_positional_encoding
=
FLAGS
.
use_positional_encoding
)
model_2plus1d
=
movinet_model
.
MovinetClassifier
(
...
...
@@ -56,6 +59,7 @@ def main(_) -> None:
backbone_3d_2plus1d
=
movinet
.
Movinet
(
model_id
=
FLAGS
.
model_id
,
causal
=
FLAGS
.
causal
,
se_type
=
FLAGS
.
se_type
,
conv_type
=
'3d_2plus1d'
,
use_positional_encoding
=
FLAGS
.
use_positional_encoding
)
model_3d_2plus1d
=
movinet_model
.
MovinetClassifier
(
...
...
official/vision/beta/projects/movinet/tools/convert_3d_2plus1d_test.py
View file @
98a558b7
...
...
@@ -36,6 +36,7 @@ class Convert3d2plus1dTest(tf.test.TestCase):
model_3d_2plus1d
=
movinet_model
.
MovinetClassifier
(
backbone
=
movinet
.
Movinet
(
model_id
=
'a0'
,
se_type
=
'2plus3d'
,
conv_type
=
'3d_2plus1d'
),
num_classes
=
600
)
model_3d_2plus1d
.
build
([
1
,
1
,
1
,
1
,
3
])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment