"git@developer.sourcefind.cn:modelzoo/resnet50_tensorflow.git" did not exist on "1b2c67af8a035fa90dc1dc507cdd101df3f5a589"
Commit f739ec8d authored by Dan Kondratyuk's avatar Dan Kondratyuk Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 373821916
parent 87fd3922
...@@ -22,7 +22,7 @@ python3 export_saved_model.py \ ...@@ -22,7 +22,7 @@ python3 export_saved_model.py \
--output_path=/tmp/movinet/ \ --output_path=/tmp/movinet/ \
--model_id=a0 \ --model_id=a0 \
--causal=True \ --causal=True \
--use_2plus1d=False \ --conv_type="3d" \
--num_classes=600 \ --num_classes=600 \
--checkpoint_path="" --checkpoint_path=""
``` ```
...@@ -65,8 +65,14 @@ flags.DEFINE_string( ...@@ -65,8 +65,14 @@ flags.DEFINE_string(
'model_id', 'a0', 'MoViNet model name.') 'model_id', 'a0', 'MoViNet model name.')
flags.DEFINE_bool( flags.DEFINE_bool(
'causal', False, 'Run the model in causal mode.') 'causal', False, 'Run the model in causal mode.')
flags.DEFINE_bool( flags.DEFINE_string(
'use_2plus1d', False, 'Use (2+1)D features instead of 3D.') 'conv_type', '3d',
'3d, 2plus1d, or 3d_2plus1d. 3d configures the network '
'to use the default 3D convolution. 2plus1d uses (2+1)D convolution '
'with Conv2D operations and 2D reshaping (e.g., a 5x3x3 kernel becomes '
'3x3 followed by 5x1 conv). 3d_2plus1d uses (2+1)D convolution with '
'Conv3D and no 2D reshaping (e.g., a 5x3x3 kernel becomes 1x3x3 '
'followed by 5x1x1 conv).')
flags.DEFINE_integer( flags.DEFINE_integer(
'num_classes', 600, 'The number of classes for prediction.') 'num_classes', 600, 'The number of classes for prediction.')
flags.DEFINE_string( flags.DEFINE_string(
...@@ -86,7 +92,7 @@ def main(argv: Sequence[str]) -> None: ...@@ -86,7 +92,7 @@ def main(argv: Sequence[str]) -> None:
input_shape = [1, 1, 1, 1, 3] input_shape = [1, 1, 1, 1, 3]
backbone = movinet.Movinet( backbone = movinet.Movinet(
FLAGS.model_id, causal=FLAGS.causal, use_2plus1d=FLAGS.use_2plus1d) FLAGS.model_id, causal=FLAGS.causal, conv_type=FLAGS.conv_type)
model = movinet_model.MovinetClassifier( model = movinet_model.MovinetClassifier(
backbone, num_classes=FLAGS.num_classes, output_states=FLAGS.causal) backbone, num_classes=FLAGS.num_classes, output_states=FLAGS.causal)
model.build(input_shape) model.build(input_shape)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment