"git@developer.sourcefind.cn:hehl2/torchaudio.git" did not exist on "48a0c17a3d233dd835d8cdd9508f6f4e3d03dcc6"
Commit 106fc83a authored by Liangzhe Yuan's avatar Liangzhe Yuan Committed by A. Unique TensorFlower
Browse files

#movinet Move export_saved_model.py to movinet/tools/ and add support for...

#movinet Move export_saved_model.py to movinet/tools/ and add support for customized classifier activation.

PiperOrigin-RevId: 425669451
parent a6d28318
...@@ -82,6 +82,9 @@ flags.DEFINE_string( ...@@ -82,6 +82,9 @@ flags.DEFINE_string(
flags.DEFINE_string( flags.DEFINE_string(
'activation', 'swish', 'activation', 'swish',
'The main activation to use across layers.') 'The main activation to use across layers.')
flags.DEFINE_string(
'classifier_activation', 'swish',
'The classifier activation to use.')
flags.DEFINE_string( flags.DEFINE_string(
'gating_activation', 'sigmoid', 'gating_activation', 'sigmoid',
'The gating activation to use in squeeze-excitation layers.') 'The gating activation to use in squeeze-excitation layers.')
...@@ -124,11 +127,15 @@ def main(_) -> None: ...@@ -124,11 +127,15 @@ def main(_) -> None:
# states. These dimensions can be set to `None` once the model is built. # states. These dimensions can be set to `None` once the model is built.
input_shape = [1 if s is None else s for s in input_specs.shape] input_shape = [1 if s is None else s for s in input_specs.shape]
# Override swish activation implementation to remove custom gradients
activation = FLAGS.activation activation = FLAGS.activation
if activation == 'swish': if activation == 'swish':
# Override swish activation implementation to remove custom gradients
activation = 'simple_swish' activation = 'simple_swish'
classifier_activation = FLAGS.classifier_activation
if classifier_activation == 'swish':
classifier_activation = 'simple_swish'
backbone = movinet.Movinet( backbone = movinet.Movinet(
model_id=FLAGS.model_id, model_id=FLAGS.model_id,
causal=FLAGS.causal, causal=FLAGS.causal,
...@@ -145,9 +152,7 @@ def main(_) -> None: ...@@ -145,9 +152,7 @@ def main(_) -> None:
num_classes=FLAGS.num_classes, num_classes=FLAGS.num_classes,
output_states=FLAGS.causal, output_states=FLAGS.causal,
input_specs=dict(image=input_specs), input_specs=dict(image=input_specs),
# TODO(dankondratyuk): currently set to swish, but will need to activation=classifier_activation)
# re-train to use other activations.
activation='simple_swish')
model.build(input_shape) model.build(input_shape)
# Compile model to generate some internal Keras variables. # Compile model to generate some internal Keras variables.
......
...@@ -18,7 +18,7 @@ from absl import flags ...@@ -18,7 +18,7 @@ from absl import flags
import tensorflow as tf import tensorflow as tf
import tensorflow_hub as hub import tensorflow_hub as hub
from official.projects.movinet import export_saved_model from official.projects.movinet.tools import export_saved_model
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment