Commit 9d1fe069 authored by Dan Kondratyuk's avatar Dan Kondratyuk Committed by A. Unique TensorFlower
Browse files

Apply stream buffer after the spatial convolution in (2+1)D mode.

PiperOrigin-RevId: 378923791
parent a9d5da28
...@@ -525,7 +525,6 @@ class Movinet(tf.keras.Model): ...@@ -525,7 +525,6 @@ class Movinet(tf.keras.Model):
Returns: Returns:
A dict mapping state names to state shapes. A dict mapping state names to state shapes.
""" """
def divide_resolution(shape, num_downsamples): def divide_resolution(shape, num_downsamples):
"""Downsamples the dimension to calculate strided convolution shape.""" """Downsamples the dimension to calculate strided convolution shape."""
if shape is None: if shape is None:
...@@ -564,6 +563,12 @@ class Movinet(tf.keras.Model): ...@@ -564,6 +563,12 @@ class Movinet(tf.keras.Model):
for layer_idx, layer in enumerate(params): for layer_idx, layer in enumerate(params):
expand_filters, kernel_size, strides = layer expand_filters, kernel_size, strides = layer
# If we use a 2D kernel, we apply spatial downsampling
# before the buffer.
if (tuple(strides[1:3]) != (1, 1) and
self._conv_type in ['2plus1d', '3d_2plus1d']):
num_downsamples += 1
if kernel_size[0] > 1: if kernel_size[0] > 1:
states[f'state/b{block_idx}/l{layer_idx}/stream_buffer'] = ( states[f'state/b{block_idx}/l{layer_idx}/stream_buffer'] = (
input_shape[0], input_shape[0],
...@@ -585,7 +590,11 @@ class Movinet(tf.keras.Model): ...@@ -585,7 +590,11 @@ class Movinet(tf.keras.Model):
if strides[1] != strides[2]: if strides[1] != strides[2]:
raise ValueError('Strides must match in the spatial dimensions, ' raise ValueError('Strides must match in the spatial dimensions, '
'got {}'.format(strides)) 'got {}'.format(strides))
if strides[1] != 1 or strides[2] != 1:
# If we use a 3D kernel, we apply spatial downsampling
# after the buffer.
if (tuple(strides[1:3]) != (1, 1) and
self._conv_type not in ['2plus1d', '3d_2plus1d']):
num_downsamples += 1 num_downsamples += 1
elif isinstance(block, HeadSpec): elif isinstance(block, HeadSpec):
states['state/head/pool_buffer'] = ( states['state/head/pool_buffer'] = (
......
...@@ -633,9 +633,28 @@ class StreamConvBlock(ConvBlock): ...@@ -633,9 +633,28 @@ class StreamConvBlock(ConvBlock):
states = dict(states) if states is not None else {} states = dict(states) if states is not None else {}
x = inputs x = inputs
# If we have no separate temporal conv, use the buffer before the 3D conv.
if self._conv_temporal is None and self._stream_buffer is not None:
x, states = self._stream_buffer(x, states=states)
x = self._conv(x)
if self._batch_norm is not None:
x = self._batch_norm(x)
if self._activation_layer is not None:
x = self._activation_layer(x)
if self._conv_temporal is not None:
if self._stream_buffer is not None: if self._stream_buffer is not None:
# If we have a separate temporal conv, use the buffer before the
# 1D conv instead (otherwise, we may waste computation on the 2D conv).
x, states = self._stream_buffer(x, states=states) x, states = self._stream_buffer(x, states=states)
x = super(StreamConvBlock, self).call(x)
x = self._conv_temporal(x)
if self._batch_norm_temporal is not None:
x = self._batch_norm_temporal(x)
if self._activation_layer is not None:
x = self._activation_layer(x)
return x, states return x, states
......
...@@ -115,15 +115,31 @@ class MovinetClassifier(tf.keras.Model): ...@@ -115,15 +115,31 @@ class MovinetClassifier(tf.keras.Model):
inputs = {**states, 'image': image} inputs = {**states, 'image': image}
if backbone.use_external_states: if backbone.use_external_states:
before_states = set(states) before_states = states
endpoints, states = backbone(inputs) endpoints, states = backbone(inputs)
after_states = set(states) after_states = states
new_states = after_states - before_states new_states = set(after_states) - set(before_states)
if new_states: if new_states:
raise AttributeError('Expected input and output states to be the same. ' raise ValueError(
'Got extra states {}, expected {}'.format( 'Expected input and output states to be the same. Got extra states '
new_states, before_states)) '{}, expected {}'.format(new_states, set(before_states)))
mismatched_shapes = {}
for name in after_states:
before_shape = before_states[name].shape
after_shape = after_states[name].shape
if len(before_shape) != len(after_shape):
mismatched_shapes[name] = (before_shape, after_shape)
continue
for before, after in zip(before_shape, after_shape):
if before is not None and after is not None and before != after:
mismatched_shapes[name] = (before_shape, after_shape)
break
if mismatched_shapes:
raise ValueError(
'Got mismatched input and output state shapes: {}'.format(
mismatched_shapes))
else: else:
endpoints, states = backbone(inputs) endpoints, states = backbone(inputs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment