Commit 98a558b7 authored by Liangzhe Yuan's avatar Liangzhe Yuan Committed by A. Unique TensorFlower
Browse files

#movinet Add se_type option in tools/convert_3d_2plus1d.py

PiperOrigin-RevId: 420368239
parent d58be675
...@@ -29,6 +29,8 @@ flags.DEFINE_string( ...@@ -29,6 +29,8 @@ flags.DEFINE_string(
'Export path to save the saved_model file.') 'Export path to save the saved_model file.')
flags.DEFINE_string( flags.DEFINE_string(
'model_id', 'a0', 'MoViNet model name.') 'model_id', 'a0', 'MoViNet model name.')
flags.DEFINE_string(
'se_type', '2plus3d', 'MoViNet model SE type.')
flags.DEFINE_bool( flags.DEFINE_bool(
'causal', True, 'Run the model in causal mode.') 'causal', True, 'Run the model in causal mode.')
flags.DEFINE_bool( flags.DEFINE_bool(
...@@ -46,6 +48,7 @@ def main(_) -> None: ...@@ -46,6 +48,7 @@ def main(_) -> None:
backbone_2plus1d = movinet.Movinet( backbone_2plus1d = movinet.Movinet(
model_id=FLAGS.model_id, model_id=FLAGS.model_id,
causal=FLAGS.causal, causal=FLAGS.causal,
se_type=FLAGS.se_type,
conv_type='2plus1d', conv_type='2plus1d',
use_positional_encoding=FLAGS.use_positional_encoding) use_positional_encoding=FLAGS.use_positional_encoding)
model_2plus1d = movinet_model.MovinetClassifier( model_2plus1d = movinet_model.MovinetClassifier(
...@@ -56,6 +59,7 @@ def main(_) -> None: ...@@ -56,6 +59,7 @@ def main(_) -> None:
backbone_3d_2plus1d = movinet.Movinet( backbone_3d_2plus1d = movinet.Movinet(
model_id=FLAGS.model_id, model_id=FLAGS.model_id,
causal=FLAGS.causal, causal=FLAGS.causal,
se_type=FLAGS.se_type,
conv_type='3d_2plus1d', conv_type='3d_2plus1d',
use_positional_encoding=FLAGS.use_positional_encoding) use_positional_encoding=FLAGS.use_positional_encoding)
model_3d_2plus1d = movinet_model.MovinetClassifier( model_3d_2plus1d = movinet_model.MovinetClassifier(
......
...@@ -36,6 +36,7 @@ class Convert3d2plus1dTest(tf.test.TestCase): ...@@ -36,6 +36,7 @@ class Convert3d2plus1dTest(tf.test.TestCase):
model_3d_2plus1d = movinet_model.MovinetClassifier( model_3d_2plus1d = movinet_model.MovinetClassifier(
backbone=movinet.Movinet( backbone=movinet.Movinet(
model_id='a0', model_id='a0',
se_type='2plus3d',
conv_type='3d_2plus1d'), conv_type='3d_2plus1d'),
num_classes=600) num_classes=600)
model_3d_2plus1d.build([1, 1, 1, 1, 3]) model_3d_2plus1d.build([1, 1, 1, 1, 3])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment