Commit 6bc2f342 authored by Dan Kondratyuk's avatar Dan Kondratyuk Committed by A. Unique TensorFlower
Browse files

Fix MoViNet TF Lite state init by replacing '/' with '_' in state names.

PiperOrigin-RevId: 391332984
parent 7a9fc982
......@@ -425,7 +425,7 @@ class PositionalEncoding(tf.keras.layers.Layer):
self._rezero = Scale(initializer=initializer, name='rezero')
state_prefix = state_prefix if state_prefix is not None else ''
self._state_prefix = state_prefix
self._frame_count_name = f'{state_prefix}/pos_enc_frame_count'
self._frame_count_name = f'{state_prefix}_pos_enc_frame_count'
def get_config(self):
"""Returns a dictionary containing the config used for initialization."""
......@@ -523,7 +523,7 @@ class PositionalEncoding(tf.keras.layers.Layer):
inputs: An input `tf.Tensor`.
states: A `dict` of states such that, if any of the keys match for this
layer, will overwrite the contents of the buffer(s). Expected keys
include `state_prefix + '/pos_enc_frame_count'`.
include `state_prefix + '_pos_enc_frame_count'`.
output_states: A `bool`. If True, returns the output tensor and output
states. Returns just the output tensor otherwise.
......@@ -587,8 +587,8 @@ class GlobalAveragePool3D(tf.keras.layers.Layer):
state_prefix = state_prefix if state_prefix is not None else ''
self._state_prefix = state_prefix
self._state_name = f'{state_prefix}/pool_buffer'
self._frame_count_name = f'{state_prefix}/pool_frame_count'
self._state_name = f'{state_prefix}_pool_buffer'
self._frame_count_name = f'{state_prefix}_pool_frame_count'
def get_config(self):
"""Returns a dictionary containing the config used for initialization."""
......@@ -611,8 +611,8 @@ class GlobalAveragePool3D(tf.keras.layers.Layer):
inputs: An input `tf.Tensor`.
states: A `dict` of states such that, if any of the keys match for this
layer, will overwrite the contents of the buffer(s).
Expected keys include `state_prefix + '/pool_buffer'` and
`state_prefix + '/pool_frame_count'`.
Expected keys include `state_prefix + '__pool_buffer'` and
`state_prefix + '__pool_frame_count'`.
output_states: A `bool`. If True, returns the output tensor and output
states. Returns just the output tensor otherwise.
......
......@@ -338,7 +338,7 @@ with the Python API:
```python
# Create the interpreter and signature runner
interpreter = tf.lite.Interpreter('/tmp/movinet_a0_stream.tflite')
signature = interpreter.get_signature_runner()
runner = interpreter.get_signature_runner()
# Extract state names and create the initial (zero) states
def state_name(name: str) -> str:
......@@ -358,7 +358,7 @@ clips = tf.split(video, video.shape[1], axis=1)
states = init_states
for clip in clips:
# Input shape: [1, 1, 172, 172, 3]
outputs = signature(**states, image=clip)
outputs = runner(**states, image=clip)
logits = outputs.pop('logits')
states = outputs
```
......
......@@ -99,8 +99,6 @@ class ExportSavedModelTest(tf.test.TestCase):
self.assertAllClose(outputs, expected_outputs, 1e-5, 1e-5)
def test_movinet_export_a0_stream_with_tflite(self):
self.skipTest('b/195800800')
saved_model_path = self.get_temp_dir()
FLAGS.export_path = saved_model_path
......@@ -123,7 +121,7 @@ class ExportSavedModelTest(tf.test.TestCase):
tflite_model = converter.convert()
interpreter = tf.lite.Interpreter(model_content=tflite_model)
signature = interpreter.get_signature_runner()
runner = interpreter.get_signature_runner('serving_default')
def state_name(name: str) -> str:
return name[len('serving_default_'):-len(':0')]
......@@ -139,7 +137,7 @@ class ExportSavedModelTest(tf.test.TestCase):
states = init_states
for clip in clips:
outputs = signature(**states, image=clip)
outputs = runner(**states, image=clip)
logits = outputs.pop('logits')
states = outputs
......
......@@ -17,10 +17,10 @@
Reference: https://arxiv.org/pdf/2103.11511.pdf
"""
import dataclasses
import math
from typing import Dict, Mapping, Optional, Sequence, Tuple, Union
import dataclasses
import tensorflow as tf
from official.modeling import hyperparams
......@@ -454,7 +454,7 @@ class Movinet(tf.keras.Model):
stochastic_depth_idx = 1
for block_idx, block in enumerate(self._block_specs):
if isinstance(block, StemSpec):
x, states = movinet_layers.Stem(
layer_obj = movinet_layers.Stem(
block.filters,
block.kernel_size,
block.strides,
......@@ -466,9 +466,9 @@ class Movinet(tf.keras.Model):
batch_norm_layer=self._norm,
batch_norm_momentum=self._norm_momentum,
batch_norm_epsilon=self._norm_epsilon,
state_prefix='state/stem',
name='stem')(
x, states=states)
state_prefix='state_stem',
name='stem')
x, states = layer_obj(x, states=states)
endpoints['stem'] = x
elif isinstance(block, MovinetBlockSpec):
if not (len(block.expand_filters) == len(block.kernel_sizes) ==
......@@ -486,8 +486,8 @@ class Movinet(tf.keras.Model):
self._stochastic_depth_drop_rate * stochastic_depth_idx /
num_layers)
expand_filters, kernel_size, strides = layer
name = f'b{block_idx-1}/l{layer_idx}'
x, states = movinet_layers.MovinetBlock(
name = f'block{block_idx-1}_layer{layer_idx}'
layer_obj = movinet_layers.MovinetBlock(
block.base_filters,
expand_filters,
kernel_size=kernel_size,
......@@ -505,13 +505,14 @@ class Movinet(tf.keras.Model):
batch_norm_layer=self._norm,
batch_norm_momentum=self._norm_momentum,
batch_norm_epsilon=self._norm_epsilon,
state_prefix=f'state/{name}',
name=name)(
x, states=states)
state_prefix=f'state_{name}',
name=name)
x, states = layer_obj(x, states=states)
endpoints[name] = x
stochastic_depth_idx += 1
elif isinstance(block, HeadSpec):
x, states = movinet_layers.Head(
layer_obj = movinet_layers.Head(
project_filters=block.project_filters,
conv_type=self._conv_type,
activation=self._activation,
......@@ -520,9 +521,9 @@ class Movinet(tf.keras.Model):
batch_norm_layer=self._norm,
batch_norm_momentum=self._norm_momentum,
batch_norm_epsilon=self._norm_epsilon,
state_prefix='state/head',
name='head')(
x, states=states)
state_prefix='state_head',
name='head')
x, states = layer_obj(x, states=states)
endpoints['head'] = x
else:
raise ValueError('Unknown block type {}'.format(block))
......@@ -567,7 +568,7 @@ class Movinet(tf.keras.Model):
for block_idx, block in enumerate(block_specs):
if isinstance(block, StemSpec):
if block.kernel_size[0] > 1:
states['state/stem/stream_buffer'] = (
states['state_stem_stream_buffer'] = (
input_shape[0],
input_shape[1],
divide_resolution(input_shape[2], num_downsamples),
......@@ -590,8 +591,10 @@ class Movinet(tf.keras.Model):
self._conv_type in ['2plus1d', '3d_2plus1d']):
num_downsamples += 1
prefix = f'state_block{block_idx}_layer{layer_idx}'
if kernel_size[0] > 1:
states[f'state/b{block_idx}/l{layer_idx}/stream_buffer'] = (
states[f'{prefix}_stream_buffer'] = (
input_shape[0],
kernel_size[0] - 1,
divide_resolution(input_shape[2], num_downsamples),
......@@ -599,13 +602,13 @@ class Movinet(tf.keras.Model):
expand_filters,
)
states[f'state/b{block_idx}/l{layer_idx}/pool_buffer'] = (
states[f'{prefix}_pool_buffer'] = (
input_shape[0], 1, 1, 1, expand_filters,
)
states[f'state/b{block_idx}/l{layer_idx}/pool_frame_count'] = (1,)
states[f'{prefix}_pool_frame_count'] = (1,)
if use_positional_encoding:
name = f'state/b{block_idx}/l{layer_idx}/pos_enc_frame_count'
name = f'{prefix}_pos_enc_frame_count'
states[name] = (1,)
if strides[1] != strides[2]:
......@@ -618,10 +621,10 @@ class Movinet(tf.keras.Model):
self._conv_type not in ['2plus1d', '3d_2plus1d']):
num_downsamples += 1
elif isinstance(block, HeadSpec):
states['state/head/pool_buffer'] = (
states['state_head_pool_buffer'] = (
input_shape[0], 1, 1, 1, block.project_filters,
)
states['state/head/pool_frame_count'] = (1,)
states['state_head_pool_frame_count'] = (1,)
return states
......
......@@ -478,7 +478,7 @@ class StreamBuffer(tf.keras.layers.Layer):
state_prefix = state_prefix if state_prefix is not None else ''
self._state_prefix = state_prefix
self._state_name = f'{state_prefix}/stream_buffer'
self._state_name = f'{state_prefix}_stream_buffer'
self._buffer_size = buffer_size
def get_config(self):
......@@ -501,7 +501,7 @@ class StreamBuffer(tf.keras.layers.Layer):
inputs: the input tensor.
states: a dict of states such that, if any of the keys match for this
layer, will overwrite the contents of the buffer(s).
Expected keys include `state_prefix + '/stream_buffer'`.
Expected keys include `state_prefix + '_stream_buffer'`.
Returns:
the output tensor and states
......
......@@ -35,11 +35,11 @@ class MoViNetTest(parameterized.TestCase, tf.test.TestCase):
endpoints, states = network(inputs)
self.assertAllEqual(endpoints['stem'].shape, [1, 8, 64, 64, 8])
self.assertAllEqual(endpoints['b0/l0'].shape, [1, 8, 32, 32, 8])
self.assertAllEqual(endpoints['b1/l0'].shape, [1, 8, 16, 16, 32])
self.assertAllEqual(endpoints['b2/l0'].shape, [1, 8, 8, 8, 56])
self.assertAllEqual(endpoints['b3/l0'].shape, [1, 8, 8, 8, 56])
self.assertAllEqual(endpoints['b4/l0'].shape, [1, 8, 4, 4, 104])
self.assertAllEqual(endpoints['block0_layer0'].shape, [1, 8, 32, 32, 8])
self.assertAllEqual(endpoints['block1_layer0'].shape, [1, 8, 16, 16, 32])
self.assertAllEqual(endpoints['block2_layer0'].shape, [1, 8, 8, 8, 56])
self.assertAllEqual(endpoints['block3_layer0'].shape, [1, 8, 8, 8, 56])
self.assertAllEqual(endpoints['block4_layer0'].shape, [1, 8, 4, 4, 104])
self.assertAllEqual(endpoints['head'].shape, [1, 1, 1, 1, 480])
self.assertNotEmpty(states)
......@@ -59,11 +59,11 @@ class MoViNetTest(parameterized.TestCase, tf.test.TestCase):
endpoints, new_states = backbone({**init_states, 'image': inputs})
self.assertAllEqual(endpoints['stem'].shape, [1, 8, 64, 64, 8])
self.assertAllEqual(endpoints['b0/l0'].shape, [1, 8, 32, 32, 8])
self.assertAllEqual(endpoints['b1/l0'].shape, [1, 8, 16, 16, 32])
self.assertAllEqual(endpoints['b2/l0'].shape, [1, 8, 8, 8, 56])
self.assertAllEqual(endpoints['b3/l0'].shape, [1, 8, 8, 8, 56])
self.assertAllEqual(endpoints['b4/l0'].shape, [1, 8, 4, 4, 104])
self.assertAllEqual(endpoints['block0_layer0'].shape, [1, 8, 32, 32, 8])
self.assertAllEqual(endpoints['block1_layer0'].shape, [1, 8, 16, 16, 32])
self.assertAllEqual(endpoints['block2_layer0'].shape, [1, 8, 8, 8, 56])
self.assertAllEqual(endpoints['block3_layer0'].shape, [1, 8, 8, 8, 56])
self.assertAllEqual(endpoints['block4_layer0'].shape, [1, 8, 4, 4, 104])
self.assertAllEqual(endpoints['head'].shape, [1, 1, 1, 1, 480])
self.assertNotEmpty(init_states)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment