Commit 81ad46bf authored by Dan Kondratyuk, committed by A. Unique TensorFlower
Browse files

Fix MoViNet TF Lite state init by replacing '/' with '_' in state names.

PiperOrigin-RevId: 391332984
parent cb62fdcc
...@@ -425,7 +425,7 @@ class PositionalEncoding(tf.keras.layers.Layer): ...@@ -425,7 +425,7 @@ class PositionalEncoding(tf.keras.layers.Layer):
self._rezero = Scale(initializer=initializer, name='rezero') self._rezero = Scale(initializer=initializer, name='rezero')
state_prefix = state_prefix if state_prefix is not None else '' state_prefix = state_prefix if state_prefix is not None else ''
self._state_prefix = state_prefix self._state_prefix = state_prefix
self._frame_count_name = f'{state_prefix}/pos_enc_frame_count' self._frame_count_name = f'{state_prefix}_pos_enc_frame_count'
def get_config(self): def get_config(self):
"""Returns a dictionary containing the config used for initialization.""" """Returns a dictionary containing the config used for initialization."""
...@@ -523,7 +523,7 @@ class PositionalEncoding(tf.keras.layers.Layer): ...@@ -523,7 +523,7 @@ class PositionalEncoding(tf.keras.layers.Layer):
inputs: An input `tf.Tensor`. inputs: An input `tf.Tensor`.
states: A `dict` of states such that, if any of the keys match for this states: A `dict` of states such that, if any of the keys match for this
layer, will overwrite the contents of the buffer(s). Expected keys layer, will overwrite the contents of the buffer(s). Expected keys
include `state_prefix + '/pos_enc_frame_count'`. include `state_prefix + '_pos_enc_frame_count'`.
output_states: A `bool`. If True, returns the output tensor and output output_states: A `bool`. If True, returns the output tensor and output
states. Returns just the output tensor otherwise. states. Returns just the output tensor otherwise.
...@@ -587,8 +587,8 @@ class GlobalAveragePool3D(tf.keras.layers.Layer): ...@@ -587,8 +587,8 @@ class GlobalAveragePool3D(tf.keras.layers.Layer):
state_prefix = state_prefix if state_prefix is not None else '' state_prefix = state_prefix if state_prefix is not None else ''
self._state_prefix = state_prefix self._state_prefix = state_prefix
self._state_name = f'{state_prefix}/pool_buffer' self._state_name = f'{state_prefix}_pool_buffer'
self._frame_count_name = f'{state_prefix}/pool_frame_count' self._frame_count_name = f'{state_prefix}_pool_frame_count'
def get_config(self): def get_config(self):
"""Returns a dictionary containing the config used for initialization.""" """Returns a dictionary containing the config used for initialization."""
...@@ -611,8 +611,8 @@ class GlobalAveragePool3D(tf.keras.layers.Layer): ...@@ -611,8 +611,8 @@ class GlobalAveragePool3D(tf.keras.layers.Layer):
inputs: An input `tf.Tensor`. inputs: An input `tf.Tensor`.
states: A `dict` of states such that, if any of the keys match for this states: A `dict` of states such that, if any of the keys match for this
layer, will overwrite the contents of the buffer(s). layer, will overwrite the contents of the buffer(s).
Expected keys include `state_prefix + '/pool_buffer'` and Expected keys include `state_prefix + '_pool_buffer'` and
`state_prefix + '/pool_frame_count'`. `state_prefix + '_pool_frame_count'`.
output_states: A `bool`. If True, returns the output tensor and output output_states: A `bool`. If True, returns the output tensor and output
states. Returns just the output tensor otherwise. states. Returns just the output tensor otherwise.
......
...@@ -338,7 +338,7 @@ with the Python API: ...@@ -338,7 +338,7 @@ with the Python API:
```python ```python
# Create the interpreter and signature runner # Create the interpreter and signature runner
interpreter = tf.lite.Interpreter('/tmp/movinet_a0_stream.tflite') interpreter = tf.lite.Interpreter('/tmp/movinet_a0_stream.tflite')
signature = interpreter.get_signature_runner() runner = interpreter.get_signature_runner()
# Extract state names and create the initial (zero) states # Extract state names and create the initial (zero) states
def state_name(name: str) -> str: def state_name(name: str) -> str:
...@@ -358,7 +358,7 @@ clips = tf.split(video, video.shape[1], axis=1) ...@@ -358,7 +358,7 @@ clips = tf.split(video, video.shape[1], axis=1)
states = init_states states = init_states
for clip in clips: for clip in clips:
# Input shape: [1, 1, 172, 172, 3] # Input shape: [1, 1, 172, 172, 3]
outputs = signature(**states, image=clip) outputs = runner(**states, image=clip)
logits = outputs.pop('logits') logits = outputs.pop('logits')
states = outputs states = outputs
``` ```
......
...@@ -99,8 +99,6 @@ class ExportSavedModelTest(tf.test.TestCase): ...@@ -99,8 +99,6 @@ class ExportSavedModelTest(tf.test.TestCase):
self.assertAllClose(outputs, expected_outputs, 1e-5, 1e-5) self.assertAllClose(outputs, expected_outputs, 1e-5, 1e-5)
def test_movinet_export_a0_stream_with_tflite(self): def test_movinet_export_a0_stream_with_tflite(self):
self.skipTest('b/195800800')
saved_model_path = self.get_temp_dir() saved_model_path = self.get_temp_dir()
FLAGS.export_path = saved_model_path FLAGS.export_path = saved_model_path
...@@ -123,7 +121,7 @@ class ExportSavedModelTest(tf.test.TestCase): ...@@ -123,7 +121,7 @@ class ExportSavedModelTest(tf.test.TestCase):
tflite_model = converter.convert() tflite_model = converter.convert()
interpreter = tf.lite.Interpreter(model_content=tflite_model) interpreter = tf.lite.Interpreter(model_content=tflite_model)
signature = interpreter.get_signature_runner() runner = interpreter.get_signature_runner('serving_default')
def state_name(name: str) -> str: def state_name(name: str) -> str:
return name[len('serving_default_'):-len(':0')] return name[len('serving_default_'):-len(':0')]
...@@ -139,7 +137,7 @@ class ExportSavedModelTest(tf.test.TestCase): ...@@ -139,7 +137,7 @@ class ExportSavedModelTest(tf.test.TestCase):
states = init_states states = init_states
for clip in clips: for clip in clips:
outputs = signature(**states, image=clip) outputs = runner(**states, image=clip)
logits = outputs.pop('logits') logits = outputs.pop('logits')
states = outputs states = outputs
......
...@@ -17,10 +17,10 @@ ...@@ -17,10 +17,10 @@
Reference: https://arxiv.org/pdf/2103.11511.pdf Reference: https://arxiv.org/pdf/2103.11511.pdf
""" """
import dataclasses
import math import math
from typing import Dict, Mapping, Optional, Sequence, Tuple, Union from typing import Dict, Mapping, Optional, Sequence, Tuple, Union
import dataclasses
import tensorflow as tf import tensorflow as tf
from official.modeling import hyperparams from official.modeling import hyperparams
...@@ -454,7 +454,7 @@ class Movinet(tf.keras.Model): ...@@ -454,7 +454,7 @@ class Movinet(tf.keras.Model):
stochastic_depth_idx = 1 stochastic_depth_idx = 1
for block_idx, block in enumerate(self._block_specs): for block_idx, block in enumerate(self._block_specs):
if isinstance(block, StemSpec): if isinstance(block, StemSpec):
x, states = movinet_layers.Stem( layer_obj = movinet_layers.Stem(
block.filters, block.filters,
block.kernel_size, block.kernel_size,
block.strides, block.strides,
...@@ -466,9 +466,9 @@ class Movinet(tf.keras.Model): ...@@ -466,9 +466,9 @@ class Movinet(tf.keras.Model):
batch_norm_layer=self._norm, batch_norm_layer=self._norm,
batch_norm_momentum=self._norm_momentum, batch_norm_momentum=self._norm_momentum,
batch_norm_epsilon=self._norm_epsilon, batch_norm_epsilon=self._norm_epsilon,
state_prefix='state/stem', state_prefix='state_stem',
name='stem')( name='stem')
x, states=states) x, states = layer_obj(x, states=states)
endpoints['stem'] = x endpoints['stem'] = x
elif isinstance(block, MovinetBlockSpec): elif isinstance(block, MovinetBlockSpec):
if not (len(block.expand_filters) == len(block.kernel_sizes) == if not (len(block.expand_filters) == len(block.kernel_sizes) ==
...@@ -486,8 +486,8 @@ class Movinet(tf.keras.Model): ...@@ -486,8 +486,8 @@ class Movinet(tf.keras.Model):
self._stochastic_depth_drop_rate * stochastic_depth_idx / self._stochastic_depth_drop_rate * stochastic_depth_idx /
num_layers) num_layers)
expand_filters, kernel_size, strides = layer expand_filters, kernel_size, strides = layer
name = f'b{block_idx-1}/l{layer_idx}' name = f'block{block_idx-1}_layer{layer_idx}'
x, states = movinet_layers.MovinetBlock( layer_obj = movinet_layers.MovinetBlock(
block.base_filters, block.base_filters,
expand_filters, expand_filters,
kernel_size=kernel_size, kernel_size=kernel_size,
...@@ -505,13 +505,14 @@ class Movinet(tf.keras.Model): ...@@ -505,13 +505,14 @@ class Movinet(tf.keras.Model):
batch_norm_layer=self._norm, batch_norm_layer=self._norm,
batch_norm_momentum=self._norm_momentum, batch_norm_momentum=self._norm_momentum,
batch_norm_epsilon=self._norm_epsilon, batch_norm_epsilon=self._norm_epsilon,
state_prefix=f'state/{name}', state_prefix=f'state_{name}',
name=name)( name=name)
x, states=states) x, states = layer_obj(x, states=states)
endpoints[name] = x endpoints[name] = x
stochastic_depth_idx += 1 stochastic_depth_idx += 1
elif isinstance(block, HeadSpec): elif isinstance(block, HeadSpec):
x, states = movinet_layers.Head( layer_obj = movinet_layers.Head(
project_filters=block.project_filters, project_filters=block.project_filters,
conv_type=self._conv_type, conv_type=self._conv_type,
activation=self._activation, activation=self._activation,
...@@ -520,9 +521,9 @@ class Movinet(tf.keras.Model): ...@@ -520,9 +521,9 @@ class Movinet(tf.keras.Model):
batch_norm_layer=self._norm, batch_norm_layer=self._norm,
batch_norm_momentum=self._norm_momentum, batch_norm_momentum=self._norm_momentum,
batch_norm_epsilon=self._norm_epsilon, batch_norm_epsilon=self._norm_epsilon,
state_prefix='state/head', state_prefix='state_head',
name='head')( name='head')
x, states=states) x, states = layer_obj(x, states=states)
endpoints['head'] = x endpoints['head'] = x
else: else:
raise ValueError('Unknown block type {}'.format(block)) raise ValueError('Unknown block type {}'.format(block))
...@@ -567,7 +568,7 @@ class Movinet(tf.keras.Model): ...@@ -567,7 +568,7 @@ class Movinet(tf.keras.Model):
for block_idx, block in enumerate(block_specs): for block_idx, block in enumerate(block_specs):
if isinstance(block, StemSpec): if isinstance(block, StemSpec):
if block.kernel_size[0] > 1: if block.kernel_size[0] > 1:
states['state/stem/stream_buffer'] = ( states['state_stem_stream_buffer'] = (
input_shape[0], input_shape[0],
input_shape[1], input_shape[1],
divide_resolution(input_shape[2], num_downsamples), divide_resolution(input_shape[2], num_downsamples),
...@@ -590,8 +591,10 @@ class Movinet(tf.keras.Model): ...@@ -590,8 +591,10 @@ class Movinet(tf.keras.Model):
self._conv_type in ['2plus1d', '3d_2plus1d']): self._conv_type in ['2plus1d', '3d_2plus1d']):
num_downsamples += 1 num_downsamples += 1
prefix = f'state_block{block_idx}_layer{layer_idx}'
if kernel_size[0] > 1: if kernel_size[0] > 1:
states[f'state/b{block_idx}/l{layer_idx}/stream_buffer'] = ( states[f'{prefix}_stream_buffer'] = (
input_shape[0], input_shape[0],
kernel_size[0] - 1, kernel_size[0] - 1,
divide_resolution(input_shape[2], num_downsamples), divide_resolution(input_shape[2], num_downsamples),
...@@ -599,13 +602,13 @@ class Movinet(tf.keras.Model): ...@@ -599,13 +602,13 @@ class Movinet(tf.keras.Model):
expand_filters, expand_filters,
) )
states[f'state/b{block_idx}/l{layer_idx}/pool_buffer'] = ( states[f'{prefix}_pool_buffer'] = (
input_shape[0], 1, 1, 1, expand_filters, input_shape[0], 1, 1, 1, expand_filters,
) )
states[f'state/b{block_idx}/l{layer_idx}/pool_frame_count'] = (1,) states[f'{prefix}_pool_frame_count'] = (1,)
if use_positional_encoding: if use_positional_encoding:
name = f'state/b{block_idx}/l{layer_idx}/pos_enc_frame_count' name = f'{prefix}_pos_enc_frame_count'
states[name] = (1,) states[name] = (1,)
if strides[1] != strides[2]: if strides[1] != strides[2]:
...@@ -618,10 +621,10 @@ class Movinet(tf.keras.Model): ...@@ -618,10 +621,10 @@ class Movinet(tf.keras.Model):
self._conv_type not in ['2plus1d', '3d_2plus1d']): self._conv_type not in ['2plus1d', '3d_2plus1d']):
num_downsamples += 1 num_downsamples += 1
elif isinstance(block, HeadSpec): elif isinstance(block, HeadSpec):
states['state/head/pool_buffer'] = ( states['state_head_pool_buffer'] = (
input_shape[0], 1, 1, 1, block.project_filters, input_shape[0], 1, 1, 1, block.project_filters,
) )
states['state/head/pool_frame_count'] = (1,) states['state_head_pool_frame_count'] = (1,)
return states return states
......
...@@ -478,7 +478,7 @@ class StreamBuffer(tf.keras.layers.Layer): ...@@ -478,7 +478,7 @@ class StreamBuffer(tf.keras.layers.Layer):
state_prefix = state_prefix if state_prefix is not None else '' state_prefix = state_prefix if state_prefix is not None else ''
self._state_prefix = state_prefix self._state_prefix = state_prefix
self._state_name = f'{state_prefix}/stream_buffer' self._state_name = f'{state_prefix}_stream_buffer'
self._buffer_size = buffer_size self._buffer_size = buffer_size
def get_config(self): def get_config(self):
...@@ -501,7 +501,7 @@ class StreamBuffer(tf.keras.layers.Layer): ...@@ -501,7 +501,7 @@ class StreamBuffer(tf.keras.layers.Layer):
inputs: the input tensor. inputs: the input tensor.
states: a dict of states such that, if any of the keys match for this states: a dict of states such that, if any of the keys match for this
layer, will overwrite the contents of the buffer(s). layer, will overwrite the contents of the buffer(s).
Expected keys include `state_prefix + '/stream_buffer'`. Expected keys include `state_prefix + '_stream_buffer'`.
Returns: Returns:
the output tensor and states the output tensor and states
......
...@@ -35,11 +35,11 @@ class MoViNetTest(parameterized.TestCase, tf.test.TestCase): ...@@ -35,11 +35,11 @@ class MoViNetTest(parameterized.TestCase, tf.test.TestCase):
endpoints, states = network(inputs) endpoints, states = network(inputs)
self.assertAllEqual(endpoints['stem'].shape, [1, 8, 64, 64, 8]) self.assertAllEqual(endpoints['stem'].shape, [1, 8, 64, 64, 8])
self.assertAllEqual(endpoints['b0/l0'].shape, [1, 8, 32, 32, 8]) self.assertAllEqual(endpoints['block0_layer0'].shape, [1, 8, 32, 32, 8])
self.assertAllEqual(endpoints['b1/l0'].shape, [1, 8, 16, 16, 32]) self.assertAllEqual(endpoints['block1_layer0'].shape, [1, 8, 16, 16, 32])
self.assertAllEqual(endpoints['b2/l0'].shape, [1, 8, 8, 8, 56]) self.assertAllEqual(endpoints['block2_layer0'].shape, [1, 8, 8, 8, 56])
self.assertAllEqual(endpoints['b3/l0'].shape, [1, 8, 8, 8, 56]) self.assertAllEqual(endpoints['block3_layer0'].shape, [1, 8, 8, 8, 56])
self.assertAllEqual(endpoints['b4/l0'].shape, [1, 8, 4, 4, 104]) self.assertAllEqual(endpoints['block4_layer0'].shape, [1, 8, 4, 4, 104])
self.assertAllEqual(endpoints['head'].shape, [1, 1, 1, 1, 480]) self.assertAllEqual(endpoints['head'].shape, [1, 1, 1, 1, 480])
self.assertNotEmpty(states) self.assertNotEmpty(states)
...@@ -59,11 +59,11 @@ class MoViNetTest(parameterized.TestCase, tf.test.TestCase): ...@@ -59,11 +59,11 @@ class MoViNetTest(parameterized.TestCase, tf.test.TestCase):
endpoints, new_states = backbone({**init_states, 'image': inputs}) endpoints, new_states = backbone({**init_states, 'image': inputs})
self.assertAllEqual(endpoints['stem'].shape, [1, 8, 64, 64, 8]) self.assertAllEqual(endpoints['stem'].shape, [1, 8, 64, 64, 8])
self.assertAllEqual(endpoints['b0/l0'].shape, [1, 8, 32, 32, 8]) self.assertAllEqual(endpoints['block0_layer0'].shape, [1, 8, 32, 32, 8])
self.assertAllEqual(endpoints['b1/l0'].shape, [1, 8, 16, 16, 32]) self.assertAllEqual(endpoints['block1_layer0'].shape, [1, 8, 16, 16, 32])
self.assertAllEqual(endpoints['b2/l0'].shape, [1, 8, 8, 8, 56]) self.assertAllEqual(endpoints['block2_layer0'].shape, [1, 8, 8, 8, 56])
self.assertAllEqual(endpoints['b3/l0'].shape, [1, 8, 8, 8, 56]) self.assertAllEqual(endpoints['block3_layer0'].shape, [1, 8, 8, 8, 56])
self.assertAllEqual(endpoints['b4/l0'].shape, [1, 8, 4, 4, 104]) self.assertAllEqual(endpoints['block4_layer0'].shape, [1, 8, 4, 4, 104])
self.assertAllEqual(endpoints['head'].shape, [1, 1, 1, 1, 480]) self.assertAllEqual(endpoints['head'].shape, [1, 1, 1, 1, 480])
self.assertNotEmpty(init_states) self.assertNotEmpty(init_states)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment