Commit c633f2c8 authored by Dan Kondratyuk, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 378423112
parent 5fd25faa
official/vision/beta/projects/movinet/export_saved_model.py

@@ -19,38 +19,18 @@ Export example:
 ```shell
 python3 export_saved_model.py \
-  --output_path=/tmp/movinet/ \
+  --export_path=/tmp/movinet/ \
   --model_id=a0 \
   --causal=True \
   --conv_type="3d" \
   --num_classes=600 \
-  --use_positional_encoding=False \
   --checkpoint_path=""
 ```
 
-To use an exported saved_model in various applications:
-
-```python
-import tensorflow as tf
-import tensorflow_hub as hub
-
-saved_model_path = ...
-
-inputs = tf.keras.layers.Input(
-    shape=[None, None, None, 3],
-    dtype=tf.float32)
-
-encoder = hub.KerasLayer(saved_model_path, trainable=True)
-outputs = encoder(inputs)
-
-model = tf.keras.Model(inputs, outputs)
-
-example_input = tf.ones([1, 8, 172, 172, 3])
-outputs = model(example_input, states)
-```
+To use an exported saved_model, refer to export_saved_model_test.py.
 """
 
-from typing import Sequence
-
 from absl import app
 from absl import flags
 import tensorflow as tf
@@ -59,8 +39,8 @@ from official.vision.beta.projects.movinet.modeling import movinet
 from official.vision.beta.projects.movinet.modeling import movinet_model
 
 flags.DEFINE_string(
-    'output_path', '/tmp/movinet/',
-    'Path to saved exported saved_model file.')
+    'export_path', '/tmp/movinet/',
+    'Export path to save the saved_model file.')
 flags.DEFINE_string(
     'model_id', 'a0', 'MoViNet model name.')
 flags.DEFINE_bool(
@@ -73,8 +53,20 @@ flags.DEFINE_string(
     '3x3 followed by 5x1 conv). 3d_2plus1d uses (2+1)D convolution with '
     'Conv3D and no 2D reshaping (e.g., a 5x3x3 kernel becomes 1x3x3 '
     'followed by 5x1x1 conv).')
+flags.DEFINE_bool(
+    'use_positional_encoding', False,
+    'Whether to use positional encoding (only applied when causal=True).')
 flags.DEFINE_integer(
     'num_classes', 600, 'The number of classes for prediction.')
+flags.DEFINE_integer(
+    'batch_size', None,
+    'The batch size of the input. Set to None for dynamic input.')
+flags.DEFINE_integer(
+    'num_frames', None,
+    'The number of frames of the input. Set to None for dynamic input.')
+flags.DEFINE_integer(
+    'image_size', None,
+    'The resolution of the input. Set to None for dynamic input.')
 flags.DEFINE_string(
     'checkpoint_path', '',
     'Checkpoint path to load. Leave blank for default initialization.')
@@ -82,75 +74,79 @@ flags.DEFINE_string(
 FLAGS = flags.FLAGS
 
 
-def main(argv: Sequence[str]) -> None:
-  if len(argv) > 1:
-    raise app.UsageError('Too many command-line arguments.')
+def main(_) -> None:
+  input_specs = tf.keras.layers.InputSpec(shape=[
+      FLAGS.batch_size,
+      FLAGS.num_frames,
+      FLAGS.image_size,
+      FLAGS.image_size,
+      3,
+  ])
 
   # Use dimensions of 1 except the channels to export faster,
   # since we only really need the last dimension to build and get the output
   # states. These dimensions will be set to `None` once the model is built.
-  input_shape = [1, 1, 1, 1, 3]
+  input_shape = [1 if s is None else s for s in input_specs.shape]
 
   backbone = movinet.Movinet(
-      FLAGS.model_id, causal=FLAGS.causal, conv_type=FLAGS.conv_type)
+      FLAGS.model_id,
+      causal=FLAGS.causal,
+      conv_type=FLAGS.conv_type,
+      use_external_states=FLAGS.causal,
+      input_specs=input_specs,
+      use_positional_encoding=FLAGS.use_positional_encoding)
   model = movinet_model.MovinetClassifier(
-      backbone, num_classes=FLAGS.num_classes, output_states=FLAGS.causal)
+      backbone,
+      num_classes=FLAGS.num_classes,
+      output_states=FLAGS.causal,
+      input_specs=dict(image=input_specs))
   model.build(input_shape)
 
-  # Compile model to generate some internal Keras variables.
-  model.compile()
-
   if FLAGS.checkpoint_path:
-    model.load_weights(FLAGS.checkpoint_path)
+    checkpoint = tf.train.Checkpoint(model=model)
+    status = checkpoint.restore(FLAGS.checkpoint_path)
+    status.assert_existing_objects_matched()
 
   if FLAGS.causal:
     # Call the model once to get the output states. Call again with `states`
     # input to ensure that the inputs with the `states` argument is built
-    _, states = model(dict(image=tf.ones(input_shape), states={}))
-    _, states = model(dict(image=tf.ones(input_shape), states=states))
-
-    input_spec = tf.TensorSpec(
-        shape=[None, None, None, None, 3],
-        dtype=tf.float32,
-        name='inputs')
-
-    state_specs = {}
-    for name, state in states.items():
-      shape = state.shape
-      if len(state.shape) == 5:
-        shape = [None, state.shape[1], None, None, state.shape[-1]]
-      new_spec = tf.TensorSpec(shape=shape, dtype=state.dtype, name=name)
-      state_specs[name] = new_spec
-
-    specs = (input_spec, state_specs)
-
-    # Define a tf.keras.Model with custom signatures to allow it to accept
-    # a state dict as an argument. We define it inline here because
-    # we first need to determine the shape of the state tensors before
-    # applying the `input_signature` argument to `tf.function`.
-    class ExportStateModule(tf.Module):
-      """Module with state for exporting to saved_model."""
-
-      def __init__(self, model):
-        self.model = model
-
-      @tf.function(input_signature=[input_spec])
-      def __call__(self, inputs):
-        return self.model(dict(image=inputs, states={}))
-
-      @tf.function(input_signature=[input_spec])
-      def base(self, inputs):
-        return self.model(dict(image=inputs, states={}))
-
-      @tf.function(input_signature=specs)
-      def stream(self, inputs, states):
-        return self.model(dict(image=inputs, states=states))
-
-    module = ExportStateModule(model)
-
-    tf.saved_model.save(module, FLAGS.output_path)
+    # with the full output state shapes.
+    input_image = tf.ones(input_shape)
+    _, states = model({**model.init_states(input_shape), 'image': input_image})
+    _, states = model({**states, 'image': input_image})
+
+    # Create a function to explicitly set the names of the outputs
+    def predict(inputs):
+      outputs, states = model(inputs)
+      return {**states, 'logits': outputs}
+
+    specs = {
+        name: tf.TensorSpec(spec.shape, name=name, dtype=spec.dtype)
+        for name, spec in model.initial_state_specs(
+            input_specs.shape).items()
+    }
+    specs['image'] = tf.TensorSpec(
+        input_specs.shape, dtype=model.dtype, name='image')
+
+    predict_fn = tf.function(predict, jit_compile=True)
+    predict_fn = predict_fn.get_concrete_function(specs)
+
+    init_states_fn = tf.function(model.init_states, jit_compile=True)
+    init_states_fn = init_states_fn.get_concrete_function(
+        tf.TensorSpec([5], dtype=tf.int32))
+
+    signatures = {'call': predict_fn, 'init_states': init_states_fn}
+
+    tf.keras.models.save_model(
+        model, FLAGS.export_path, signatures=signatures)
   else:
     _ = model(tf.ones(input_shape))
-    tf.keras.models.save_model(model, FLAGS.output_path)
+    tf.keras.models.save_model(model, FLAGS.export_path)
 
-  print(' ----- Done. Saved Model is saved at {}'.format(FLAGS.output_path))
+  print(' ----- Done. Saved Model is saved at {}'.format(FLAGS.export_path))
 
 if __name__ == '__main__':
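For orientation, here is a minimal sketch of driving the streaming (causal) export directly through the `init_states` and `call` signatures registered above. The export path, 172x172 resolution, and 8-frame clip are placeholders taken from the docstring example and the test file below; the exact signature calling convention shown here is an assumption rather than part of this change.

```python
import tensorflow as tf

# Placeholder path; corresponds to the --export_path flag above.
loaded = tf.saved_model.load('/tmp/movinet/')
init_states_fn = loaded.signatures['init_states']
stream_fn = loaded.signatures['call']

video = tf.ones([1, 8, 172, 172, 3])  # [batch, frames, height, width, RGB]
states = dict(init_states_fn(tf.shape(video)))

# Feed the clip one frame at a time, carrying the updated states forward.
for frame in tf.split(video, video.shape[1], axis=1):
  outputs = dict(stream_fn(**states, image=frame))
  logits = outputs.pop('logits')  # [1, num_classes] logits for the clip so far
  states = outputs
```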
New file: official/vision/beta/projects/movinet/export_saved_model_test.py
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for export_saved_model."""

from absl import flags
import tensorflow as tf
import tensorflow_hub as hub

from official.vision.beta.projects.movinet import export_saved_model

FLAGS = flags.FLAGS


class ExportSavedModelTest(tf.test.TestCase):

  def test_movinet_export_a0_base_with_tfhub(self):
    saved_model_path = self.get_temp_dir()

    FLAGS.export_path = saved_model_path
    FLAGS.model_id = 'a0'
    FLAGS.causal = False
    FLAGS.num_classes = 600

    export_saved_model.main('unused_args')

    encoder = hub.KerasLayer(saved_model_path, trainable=True)

    inputs = tf.keras.layers.Input(
        shape=[None, None, None, 3],
        dtype=tf.float32)
    outputs = encoder(dict(image=inputs))

    model = tf.keras.Model(inputs, outputs)

    example_input = tf.ones([1, 8, 172, 172, 3])
    outputs = model(example_input)

    self.assertEqual(outputs.shape, [1, 600])

  def test_movinet_export_a0_stream_with_tfhub(self):
    saved_model_path = self.get_temp_dir()

    FLAGS.export_path = saved_model_path
    FLAGS.model_id = 'a0'
    FLAGS.causal = True
    FLAGS.num_classes = 600

    export_saved_model.main('unused_args')

    encoder = hub.KerasLayer(saved_model_path, trainable=True)

    image_input = tf.keras.layers.Input(
        shape=[None, None, None, 3],
        dtype=tf.float32,
        name='image')

    init_states_fn = encoder.resolved_object.signatures['init_states']
    state_shapes = {
        name: ([s if s > 0 else None for s in state.shape], state.dtype)
        for name, state in init_states_fn(tf.constant([0, 0, 0, 0, 3])).items()
    }
    states_input = {
        name: tf.keras.Input(shape[1:], dtype=dtype, name=name)
        for name, (shape, dtype) in state_shapes.items()
    }

    inputs = {**states_input, 'image': image_input}

    outputs = encoder(inputs)

    model = tf.keras.Model(inputs, outputs)

    example_input = tf.ones([1, 8, 172, 172, 3])
    frames = tf.split(example_input, example_input.shape[1], axis=1)

    init_states = init_states_fn(tf.shape(example_input))
    expected_outputs, _ = model({**init_states, 'image': example_input})

    states = init_states
    for frame in frames:
      outputs, states = model({**states, 'image': frame})

    self.assertEqual(outputs.shape, [1, 600])
    self.assertNotEmpty(states)
    self.assertAllClose(outputs, expected_outputs, 1e-5, 1e-5)


if __name__ == '__main__':
  tf.test.main()
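Outside the test harness, the base (non-causal) export can be loaded for fine-tuning through TF Hub in the same way as the first test case above; a minimal sketch, with the export path as a placeholder and the default 600 classes assumed:

```python
import tensorflow as tf
import tensorflow_hub as hub

saved_model_path = '/tmp/movinet/'  # placeholder; set to your --export_path

# Wrap the exported classifier as a trainable Keras layer for fine-tuning.
encoder = hub.KerasLayer(saved_model_path, trainable=True)

inputs = tf.keras.layers.Input(shape=[None, None, None, 3], dtype=tf.float32)
outputs = encoder(dict(image=inputs))
model = tf.keras.Model(inputs, outputs)

example_input = tf.ones([1, 8, 172, 172, 3])
logits = model(example_input)  # expected shape: [1, 600]
```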