Commit 9572ecac authored by Dan Kondratyuk's avatar Dan Kondratyuk Committed by A. Unique TensorFlower
Browse files

Internal changes.

PiperOrigin-RevId: 361061644
parent d2967cfe
...@@ -14,9 +14,7 @@ ...@@ -14,9 +14,7 @@
# ============================================================================== # ==============================================================================
"""Contains common building blocks for neural networks.""" """Contains common building blocks for neural networks."""
from typing import Optional from typing import Any, Callable, Dict, List, Optional, Tuple, Union
# Import libraries
from absl import logging from absl import logging
import tensorflow as tf import tensorflow as tf
...@@ -24,6 +22,11 @@ import tensorflow as tf ...@@ -24,6 +22,11 @@ import tensorflow as tf
from official.modeling import tf_utils from official.modeling import tf_utils
# Type annotations
# `States` maps a state name (string key) to the tensor stored under it; it is
# the annotation used for the `states` argument of streaming layers below.
States = Dict[str, tf.Tensor]
# `Activation` is either a string or a callable.
# NOTE(review): presumably an activation name or function; it is not used in
# the part of the file visible here -- confirm against callers.
Activation = Union[str, Callable]
def make_divisible(value: float, def make_divisible(value: float,
divisor: int, divisor: int,
min_value: Optional[float] = None min_value: Optional[float] = None
...@@ -270,3 +273,597 @@ def pyramid_feature_fusion(inputs, target_level): ...@@ -270,3 +273,597 @@ def pyramid_feature_fusion(inputs, target_level):
resampled_feats.append(feat) resampled_feats.append(feat)
return tf.math.add_n(resampled_feats) return tf.math.add_n(resampled_feats)
@tf.keras.utils.register_keras_serializable(package='Vision')
class Scale(tf.keras.layers.Layer):
  """Multiplies the input by a single trainable scalar.

  A learnable scalar gate like this is what ReZero applies to a layer, which
  improves convergence speed.

  Reference: https://arxiv.org/pdf/2003.04887.pdf
  """

  def __init__(
      self,
      initializer: tf.keras.initializers.Initializer = 'ones',
      regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
      **kwargs):
    """Creates the layer together with its scalar weight.

    Args:
      initializer: initializer used for the scalar weight.
      regularizer: optional regularizer applied to the scalar weight.
      **kwargs: keyword arguments forwarded to the base Keras layer.
    """
    super().__init__(**kwargs)
    self._initializer = initializer
    self._regularizer = regularizer
    # `shape=[]` makes the weight a rank-0 (scalar) variable. It is created
    # here rather than in `build` because it does not depend on input shape.
    self._scale = self.add_weight(
        name='scale',
        shape=[],
        dtype=self.dtype,
        initializer=initializer,
        regularizer=regularizer,
        trainable=True)

  def get_config(self):
    """Returns the constructor arguments merged into the base layer config."""
    merged = dict(super().get_config())
    merged.update({
        'initializer': self._initializer,
        'regularizer': self._regularizer,
    })
    return merged

  def call(self, inputs):
    """Returns `inputs` multiplied by the scalar, cast to the input dtype."""
    return tf.cast(self._scale, inputs.dtype) * inputs
@tf.keras.utils.register_keras_serializable(package='Vision')
class TemporalSoftmaxPool(tf.keras.layers.Layer):
  """Network layer corresponding to temporal softmax pooling.

  This is useful for multi-class logits (used in e.g., Charades).
  Modified from AssembleNet Charades evaluation.

  Reference: https://arxiv.org/pdf/1905.13209.pdf.
  """

  def call(self, inputs):
    """Weights each frame of `inputs` by a softmax taken over axis 1.

    Args:
      inputs: a tensor of rank 3, 4 or 5 whose axis 1 is the frame axis.

    Returns:
      A tensor with the same shape as `inputs`: the input multiplied
      elementwise by `softmax(inputs / sqrt(num_frames), axis=1)`.

    Raises:
      ValueError: if the static rank of `inputs` is not 3, 4 or 5.
    """
    rank = inputs.shape.rank
    # Raise explicitly instead of using `assert`, which is stripped when
    # Python runs with -O and would let unsupported ranks through silently.
    if rank not in (3, 4, 5):
      raise ValueError(
          'Input should have rank 3, 4 or 5, got {}'.format(rank))
    frames = tf.shape(inputs)[1]
    # Scale the logits down by sqrt(num_frames) before the softmax.
    pre_logits = inputs / tf.sqrt(tf.cast(frames, inputs.dtype))
    activations = tf.nn.softmax(pre_logits, axis=1)
    outputs = inputs * activations
    return outputs
@tf.keras.utils.register_keras_serializable(package='Vision')
class PositionalEncoding(tf.keras.layers.Layer):
  """Adds a sinusoidal positional encoding along the frame axis.

  The encoding varies with the frame index (axis 1) and is added to the input
  after being multiplied by a trainable scalar. With the default 'zeros'
  initializer that scalar starts at 0, so the network can choose to ignore
  the encoding.

  Reference: https://arxiv.org/pdf/1706.03762.pdf
  """

  def __init__(self,
               initializer: tf.keras.initializers.Initializer = 'zeros',
               cache_encoding: bool = False,
               **kwargs):
    """Creates the layer.

    Args:
      initializer: initializer for the scalar that weights the encoding.
      cache_encoding: if True, compute the encoding tensor once in `build` and
        reuse it. If False, recompute it on every call, which allows the
        number of input frames to vary between calls.
      **kwargs: keyword arguments forwarded to the base Keras layer.
    """
    super().__init__(**kwargs)
    self._initializer = initializer
    self._cache_encoding = cache_encoding
    # Only populated (in `build`) when caching is enabled.
    self._pos_encoding = None
    self._rezero = Scale(initializer=initializer, name='rezero')

  def get_config(self):
    """Returns the constructor arguments merged into the base layer config."""
    merged = dict(super().get_config())
    merged.update({
        'initializer': self._initializer,
        'cache_encoding': self._cache_encoding,
    })
    return merged

  def _positional_encoding(self,
                           num_positions: int,
                           hidden_size: int,
                           dtype: tf.DType = tf.float32):
    """Builds the sinusoidal encoding table.

    Args:
      num_positions: the number of positions (frames).
      hidden_size: the number of channels of each encoding vector.
      dtype: the dtype of the returned tensor.

    Returns:
      The positional encoding tensor with shape [num_positions, hidden_size].
    """
    # `tf.range` raises an error when asked for `dtype=tf.bfloat16` directly,
    # so the range is built with its default dtype and cast afterwards.
    position_column = tf.cast(tf.range(num_positions)[:, tf.newaxis], dtype)
    channel_row = tf.range(hidden_size)[tf.newaxis, :]
    exponent = tf.cast(2 * (channel_row // 2), dtype)
    exponent /= tf.cast(hidden_size, dtype)
    inverse_wavelength = 1. / tf.math.pow(10_000., exponent)
    phase = position_column * inverse_wavelength
    # Sines of the even-indexed columns come first, then cosines of the
    # odd-indexed columns.
    return tf.concat(
        [tf.math.sin(phase[:, 0::2]), tf.math.cos(phase[:, 1::2])], axis=-1)

  def _get_pos_encoding(self, input_shape):
    """Returns the encoding reshaped to [1, frames, 1, 1, channels]."""
    num_frames, num_channels = input_shape[1], input_shape[-1]
    table = self._positional_encoding(
        num_frames, num_channels, dtype=self.dtype)
    return tf.reshape(table, [1, num_frames, 1, 1, num_channels])

  def build(self, input_shape):
    """Builds the layer, caching the encoding if requested.

    Args:
      input_shape: the input shape.

    Raises:
      ValueError: if using 'channels_first' data format.
    """
    if tf.keras.backend.image_data_format() == 'channels_first':
      raise ValueError('"channels_first" mode is unsupported.')

    if self._cache_encoding:
      self._pos_encoding = self._get_pos_encoding(input_shape)

    super().build(input_shape)

  def call(self, inputs):
    """Returns `inputs` plus the scalar-weighted positional encoding."""
    encoding = (
        self._pos_encoding if self._cache_encoding
        else self._get_pos_encoding(tf.shape(inputs)))
    encoding = tf.cast(encoding, inputs.dtype)
    # No gradient flows into the encoding itself; only the rezero scalar
    # applied on the next line is a trainable part of this term.
    encoding = self._rezero(tf.stop_gradient(encoding))
    return inputs + encoding
@tf.keras.utils.register_keras_serializable(package='Vision')
class GlobalAveragePool3D(tf.keras.layers.Layer):
  """Global average pooling layer with causal mode.

  Implements causal mode, which runs a cumulative sum (with `tf.cumsum`) across
  frames in the time dimension, allowing the use of a stream buffer. Sums any
  valid input state with the current input to allow state to accumulate over
  several iterations.
  """

  def __init__(self,
               keepdims: bool = False,
               causal: bool = False,
               **kwargs):
    """Initializes global average pool.

    Args:
      keepdims: if True, keep the averaged dimensions.
      causal: run in causal mode with a cumulative sum across frames.
      **kwargs: keyword arguments to be passed to this layer.
    """
    super(GlobalAveragePool3D, self).__init__(**kwargs)

    self._keepdims = keepdims
    self._causal = causal

    self._frame_count = None

  def get_config(self):
    """Returns a dictionary containing the config used for initialization."""
    config = {
        'keepdims': self._keepdims,
        'causal': self._causal,
    }
    base_config = super(GlobalAveragePool3D, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def build(self, input_shape):
    """Builds the layer with the given input shape."""
    # Here we define strings that will uniquely reference the buffer states
    # in the TF graph. These will be used for passing in a mapping of states
    # for streaming mode. To do this, we can use a name scope.
    with tf.name_scope('buffer') as state_name:
      self._state_name = state_name
      self._frame_count_name = state_name + '_frame_count'

    super(GlobalAveragePool3D, self).build(input_shape)

  def call(self,
           inputs: tf.Tensor,
           states: Optional[States] = None,
           output_states: bool = True) -> Union[Any, Tuple[Any, States]]:
    """Calls the layer with the given inputs.

    Args:
      inputs: the input tensor.
      states: a dict of states such that, if any of the keys match for this
        layer, will overwrite the contents of the buffer(s). The dict passed
        in is copied, not modified.
      output_states: if True, returns the output tensor and output states.
        Returns just the output tensor otherwise.

    Returns:
      the output tensor (and optionally the states if `output_states=True`).
      If `causal=True`, the output tensor will have shape
      `[batch_size, num_frames, 1, 1, channels]` if `keepdims=True`. We keep
      the frame dimension in this case to simulate a cumulative global average
      as if we are inputting one frame at a time. If `causal=False`, the output
      is equivalent to `tf.keras.layers.GlobalAveragePooling3D` with shape
      `[batch_size, 1, 1, 1, channels]` if `keepdims=True` (plus the optional
      buffer stored in `states`).

    Raises:
      ValueError: if using 'channels_first' data format.
    """
    states = dict(states) if states is not None else {}

    if tf.keras.backend.image_data_format() == 'channels_first':
      raise ValueError('"channels_first" mode is unsupported.')

    # Shape: [batch_size, 1, 1, 1, channels]
    buffer = states.get(self._state_name, None)
    if buffer is None:
      buffer = tf.zeros_like(inputs[:, :1, :1, :1], dtype=inputs.dtype)
      states[self._state_name] = buffer

    # Read every size used below from the dynamic shape so that the layer
    # also works when frame or spatial dimensions are unknown (None) at
    # trace time.
    input_shape = tf.shape(inputs)

    # Keep a count of frames encountered across input iterations in
    # num_frames to be able to accurately take a cumulative average across
    # all frames when running in streaming mode
    num_frames = input_shape[1]
    frame_count = states.get(self._frame_count_name, 0)
    states[self._frame_count_name] = frame_count + num_frames

    if self._causal:
      # Take a mean of spatial dimensions to make computation more efficient.
      x = tf.reduce_mean(inputs, axis=[2, 3], keepdims=True)
      x = tf.cumsum(x, axis=1)
      x = x + buffer

      # The last frame will be the value of the next state
      # Shape: [batch_size, 1, 1, 1, channels]
      states[self._state_name] = x[:, -1:]

      # In causal mode, the divisor increments by 1 for every frame to
      # calculate cumulative averages instead of one global average
      mean_divisors = tf.range(num_frames) + frame_count + 1
      mean_divisors = tf.reshape(mean_divisors, [1, num_frames, 1, 1, 1])
      mean_divisors = tf.cast(mean_divisors, x.dtype)

      # Shape: [batch_size, num_frames, 1, 1, channels]
      x = x / mean_divisors
    else:
      # In non-causal mode, we (optionally) sum across frames to take a
      # cumulative average across input iterations rather than individual
      # frames. If no buffer state is passed, this essentially becomes
      # regular global average pooling.
      # Shape: [batch_size, 1, 1, 1, channels]
      x = tf.reduce_sum(inputs, axis=(1, 2, 3), keepdims=True)
      # Divide by the spatial size here; the division by the number of
      # frames happens after the buffer is added.
      spatial_size = input_shape[2] * input_shape[3]
      x = x / tf.cast(spatial_size, x.dtype)
      x = x + buffer

      # Shape: [batch_size, 1, 1, 1, channels]
      states[self._state_name] = x

      x = x / tf.cast(frame_count + num_frames, x.dtype)

    if not self._keepdims:
      x = tf.squeeze(x, axis=(1, 2, 3))

    return (x, states) if output_states else x
# Registered like the other layers in this file so that configs produced by
# `get_config` can be deserialized without passing `custom_objects`.
@tf.keras.utils.register_keras_serializable(package='Vision')
class SpatialAveragePool3D(tf.keras.layers.Layer):
  """Average pooling layer that pools across the spatial dimensions only.

  The mean is taken over axes 2 and 3 of a rank-5 input; the remaining axes
  (including axis 1) are left untouched.
  """

  def __init__(self, keepdims: bool = False, **kwargs):
    """Initializes spatial average pool.

    Args:
      keepdims: if True, keep the averaged dimensions.
      **kwargs: keyword arguments to be passed to this layer.
    """
    super(SpatialAveragePool3D, self).__init__(**kwargs)
    self._keepdims = keepdims

  def get_config(self):
    """Returns a dictionary containing the config used for initialization."""
    config = {
        'keepdims': self._keepdims,
    }
    base_config = super(SpatialAveragePool3D, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def build(self, input_shape):
    """Builds the layer with the given input shape.

    Args:
      input_shape: the input shape.

    Raises:
      ValueError: if using 'channels_first' data format.
    """
    if tf.keras.backend.image_data_format() == 'channels_first':
      raise ValueError('"channels_first" mode is unsupported.')

    super(SpatialAveragePool3D, self).build(input_shape)

  def call(self, inputs):
    """Averages `inputs` over axes 2 and 3.

    Args:
      inputs: a rank-5 tensor.

    Returns:
      The mean over axes 2 and 3, with those axes kept as size 1 when
      `keepdims=True` and removed otherwise.

    Raises:
      ValueError: if `inputs` does not have rank 5.
    """
    if inputs.shape.rank != 5:
      raise ValueError(
          'Input should have rank {}, got {}'.format(5, inputs.shape.rank))

    return tf.reduce_mean(inputs, axis=(2, 3), keepdims=self._keepdims)
class CausalConvMixin:
  """Mixin adding buffered causal padding to `tf.keras.layers.Conv` layers.

  The host class is expected to provide `kernel_size`, `dilation_rate`,
  `rank`, `filters`, `groups` and a `_use_buffered_input` attribute.
  """

  @property
  def use_buffered_input(self) -> bool:
    """Whether the time axis is treated as already padded by the caller."""
    return self._use_buffered_input

  @use_buffered_input.setter
  def use_buffered_input(self, variable: bool):
    self._use_buffered_input = variable

  def _compute_buffered_causal_padding(self,
                                       inputs: Optional[tf.Tensor] = None,
                                       use_buffered_input: bool = False,
                                       time_axis: int = 1) -> List[List[int]]:
    """Computes the `tf.pad` paddings that implement 'causal' convolution.

    Args:
      inputs: optional input tensor to be padded (unused).
      use_buffered_input: if True, apply no padding along the time dimension
        ('valid' behaviour). This should be set when applying the stream
        buffer.
      time_axis: the axis of the time dimension.

    Returns:
      A list of `[before, after]` pairs for `tf.pad`, one per input axis.

    Raises:
      ValueError: if using 'channels_first' data format.
    """
    del inputs  # Only the layer configuration determines the padding.

    if tf.keras.backend.image_data_format() == 'channels_first':
      raise ValueError('"channels_first" mode is unsupported.')

    # The first and last axes get no padding.
    paddings = [[0, 0]]
    for axis in range(self.rank):
      kernel = self.kernel_size[axis]
      dilation = self.dilation_rate[axis]
      # Kernel extent once dilation gaps are accounted for.
      effective_kernel = kernel + (kernel - 1) * (dilation - 1)
      total = effective_kernel - 1
      before = total // 2
      paddings.append([before, total - before])
    paddings.append([0, 0])

    if use_buffered_input:
      paddings[time_axis] = [0, 0]
    else:
      # Move all of the time padding to the front so that an output frame
      # never depends on later input frames.
      before, after = paddings[time_axis]
      paddings[time_axis] = [before + after, 0]
    return paddings

  def _causal_validate_init(self):
    """Validates the Conv layer initial configuration.

    Overriding the base validation with this method is meant to circumvent
    unnecessary errors when using causal padding.

    Raises:
      ValueError: if `filters` is not divisible by `groups`, or if
        `kernel_size` contains a zero.
    """
    filters_mismatch = (
        self.filters is not None and self.filters % self.groups != 0)
    if filters_mismatch:
      raise ValueError(
          'The number of filters must be evenly divisible by the number of '
          'groups. Received: groups={}, filters={}'.format(
              self.groups, self.filters))

    if any(not size for size in self.kernel_size):
      raise ValueError('The argument `kernel_size` cannot contain 0(s). '
                       'Received: %s' % (self.kernel_size,))

  def _buffered_spatial_output_shape(self, spatial_output_shape: List[int]):
    """Adjusts the output shape for buffered (pre-padded) input.

    With buffered input the time axis uses 'valid' padding, so the output
    length along time is the input length minus the causal padding, assuming
    the stride across time is 1. The list is modified in place and returned.
    """
    if not self._use_buffered_input:
      return spatial_output_shape

    causal_padding = self._compute_buffered_causal_padding(
        use_buffered_input=False)
    spatial_output_shape[0] -= sum(causal_padding[1])
    return spatial_output_shape
class Conv2D(tf.keras.layers.Conv2D, CausalConvMixin):
  """Conv2D variant that understands buffered causal padding.

  Supports `padding='causal'` option (like in `tf.keras.layers.Conv1D`),
  which applies causal padding to the temporal dimension, and same padding in
  the spatial dimensions.
  """

  def __init__(self, *args, use_buffered_input=False, **kwargs):
    """Creates the layer.

    Args:
      *args: positional arguments forwarded to `tf.keras.layers.Conv2D`.
      use_buffered_input: if True, the input is expected to be padded
        beforehand. In effect, calling this layer will use 'valid' padding on
        the temporal dimension to simulate 'causal' padding.
      **kwargs: keyword arguments forwarded to `tf.keras.layers.Conv2D`.
    """
    super().__init__(*args, **kwargs)
    self._use_buffered_input = use_buffered_input

  def get_config(self):
    """Returns the base conv config extended with `use_buffered_input`."""
    merged = dict(super().get_config())
    merged.update({'use_buffered_input': self._use_buffered_input})
    return merged

  def _compute_causal_padding(self, inputs):
    """Returns causal paddings that honour the buffered-input setting."""
    buffered = self._use_buffered_input
    return self._compute_buffered_causal_padding(
        inputs, use_buffered_input=buffered)

  def _validate_init(self):
    """Validates the configuration using the mixin's causal-aware checks."""
    self._causal_validate_init()

  def _spatial_output_shape(self, spatial_input_shape: List[int]):
    """Returns the base output shape, adjusted for buffered input."""
    base_shape = super()._spatial_output_shape(spatial_input_shape)
    return self._buffered_spatial_output_shape(base_shape)
class DepthwiseConv2D(tf.keras.layers.DepthwiseConv2D, CausalConvMixin):
  """DepthwiseConv2D layer supporting CausalConv.

  Supports `padding='causal'` option (like in `tf.keras.layers.Conv1D`),
  which applies causal padding to the temporal dimension, and same padding in
  the spatial dimensions.
  """

  def __init__(self, *args, use_buffered_input=False, **kwargs):
    """Initializes depthwise conv2d.

    Args:
      *args: arguments to be passed.
      use_buffered_input: if True, the input is expected to be padded
        beforehand. In effect, calling this layer will use 'valid' padding on
        the temporal dimension to simulate 'causal' padding.
      **kwargs: keyword arguments to be passed.
    """
    super(DepthwiseConv2D, self).__init__(*args, **kwargs)
    self._use_buffered_input = use_buffered_input

    # Causal padding is unsupported by default for DepthwiseConv2D,
    # so we resort to valid padding internally. However, we handle
    # causal padding as a special case with `self._is_causal`, which is
    # defined by the super class.
    if self.padding == 'causal':
      self.padding = 'valid'

  def get_config(self):
    """Returns a dictionary containing the config used for initialization.

    `__init__` rewrites `self.padding` from 'causal' to 'valid', so the base
    config would report 'valid' and a layer rebuilt from it would silently
    lose its causal behaviour. The original 'causal' value is restored here.
    """
    config = dict(super(DepthwiseConv2D, self).get_config())
    config['use_buffered_input'] = self._use_buffered_input
    if self._is_causal:
      config['padding'] = 'causal'
    return config

  def call(self, inputs):
    """Calls the layer with the given inputs."""
    if self._is_causal:
      # The base layer runs with 'valid' padding, so pad explicitly first.
      inputs = tf.pad(inputs, self._compute_causal_padding(inputs))
    return super(DepthwiseConv2D, self).call(inputs)

  def _compute_causal_padding(self, inputs):
    """Computes causal padding dimensions for the given inputs."""
    return self._compute_buffered_causal_padding(
        inputs, use_buffered_input=self._use_buffered_input)

  def _validate_init(self):
    """Validates the Conv layer initial configuration."""
    self._causal_validate_init()

  def _spatial_output_shape(self, spatial_input_shape: List[int]):
    """Computes the spatial output shape from the input shape."""
    shape = super(DepthwiseConv2D, self)._spatial_output_shape(
        spatial_input_shape)
    return self._buffered_spatial_output_shape(shape)
class Conv3D(tf.keras.layers.Conv3D, CausalConvMixin):
  """Conv3D variant that understands buffered causal padding.

  Supports `padding='causal'` option (like in `tf.keras.layers.Conv1D`),
  which applies causal padding to the temporal dimension, and same padding in
  the spatial dimensions.
  """

  def __init__(self, *args, use_buffered_input=False, **kwargs):
    """Creates the layer.

    Args:
      *args: positional arguments forwarded to `tf.keras.layers.Conv3D`.
      use_buffered_input: if True, the input is expected to be padded
        beforehand. In effect, calling this layer will use 'valid' padding on
        the temporal dimension to simulate 'causal' padding.
      **kwargs: keyword arguments forwarded to `tf.keras.layers.Conv3D`.
    """
    super().__init__(*args, **kwargs)
    self._use_buffered_input = use_buffered_input

  def get_config(self):
    """Returns the base conv config extended with `use_buffered_input`."""
    merged = dict(super().get_config())
    merged.update({'use_buffered_input': self._use_buffered_input})
    return merged

  def build(self, input_shape):
    """Builds the layer and wraps its convolution op in a compiled function."""
    super().build(input_shape)

    # TODO(b/177662019): tf.nn.conv3d with depthwise kernels on CPU
    # in eager mode may produce incorrect output or cause a segfault.
    # To avoid this issue, compile the op to TF graph using tf.function.
    eager_convolution_op = self._convolution_op
    self._convolution_op = tf.function(
        eager_convolution_op, experimental_compile=True)

  def _compute_causal_padding(self, inputs):
    """Returns causal paddings that honour the buffered-input setting."""
    buffered = self._use_buffered_input
    return self._compute_buffered_causal_padding(
        inputs, use_buffered_input=buffered)

  def _validate_init(self):
    """Validates the configuration using the mixin's causal-aware checks."""
    self._causal_validate_init()

  def _spatial_output_shape(self, spatial_input_shape: List[int]):
    """Returns the base output shape, adjusted for buffered input."""
    base_shape = super()._spatial_output_shape(spatial_input_shape)
    return self._buffered_spatial_output_shape(base_shape)
# Lint as: python3
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for nn_layers."""
# Import libraries
from absl.testing import parameterized
import tensorflow as tf
from official.vision.beta.modeling.layers import nn_layers
class NNLayersTest(parameterized.TestCase, tf.test.TestCase):
  """Unit tests for the layers defined in `nn_layers`."""

  def test_scale(self):
    """A Scale layer initialized to 10 maps the input 3 to 30."""
    scale = nn_layers.Scale(initializer=tf.keras.initializers.constant(10.))
    output = scale(3.)
    self.assertAllEqual(output, 30.)

  def test_temporal_softmax_pool(self):
    """Checks TemporalSoftmaxPool against fixed values for 4 frames."""
    inputs = tf.range(4, dtype=tf.float32) + 1.
    inputs = tf.reshape(inputs, [1, 4, 1, 1, 1])
    layer = nn_layers.TemporalSoftmaxPool()
    output = layer(inputs)
    self.assertAllClose(
        output,
        [[[[[0.10153633]]],
          [[[0.33481020]]],
          [[[0.82801306]]],
          [[[1.82021690]]]]])

  def test_positional_encoding(self):
    """Checks cached and uncached encodings against fixed values."""
    pos_encoding = nn_layers.PositionalEncoding(
        initializer='ones', cache_encoding=False)
    pos_encoding_cached = nn_layers.PositionalEncoding(
        initializer='ones', cache_encoding=True)

    inputs = tf.ones([1, 4, 1, 1, 3])
    outputs = pos_encoding(inputs)
    outputs_cached = pos_encoding_cached(inputs)

    expected = tf.constant(
        [[[[[1.0000000, 1.0000000, 2.0000000]]],
          [[[1.8414710, 1.0021545, 1.5403023]]],
          [[[1.9092975, 1.0043088, 0.5838531]]],
          [[[1.1411200, 1.0064633, 0.0100075]]]]])

    self.assertEqual(outputs.shape, expected.shape)
    self.assertAllClose(outputs, expected)

    self.assertEqual(outputs.shape, outputs_cached.shape)
    self.assertAllClose(outputs, outputs_cached)

    # The uncached layer must also accept a different number of frames.
    inputs = tf.ones([1, 5, 1, 1, 3])
    _ = pos_encoding(inputs)

  def test_positional_encoding_bfloat16(self):
    """Checks the encoding values when the input dtype is bfloat16."""
    pos_encoding = nn_layers.PositionalEncoding(initializer='ones')

    inputs = tf.ones([1, 4, 1, 1, 3], dtype=tf.bfloat16)
    outputs = pos_encoding(inputs)

    expected = tf.constant(
        [[[[[1.0000000, 1.0000000, 2.0000000]]],
          [[[1.8414710, 1.0021545, 1.5403023]]],
          [[[1.9092975, 1.0043088, 0.5838531]]],
          [[[1.1411200, 1.0064633, 0.0100075]]]]])

    self.assertEqual(outputs.shape, expected.shape)
    self.assertAllClose(outputs, expected)

  def test_global_average_pool_basic(self):
    """Pooling an all-ones input with keepdims gives a single 1."""
    pool = nn_layers.GlobalAveragePool3D(keepdims=True)

    inputs = tf.ones([1, 2, 3, 4, 1])
    outputs = pool(inputs, output_states=False)

    expected = tf.ones([1, 1, 1, 1, 1])

    self.assertEqual(outputs.shape, expected.shape)
    self.assertAllEqual(outputs, expected)

  def test_global_average_pool_keras(self):
    """Non-causal pooling matches `tf.keras.layers.GlobalAveragePooling3D`."""
    pool = nn_layers.GlobalAveragePool3D(keepdims=False)
    keras_pool = tf.keras.layers.GlobalAveragePooling3D()

    inputs = 10 * tf.random.normal([1, 2, 3, 4, 1])

    outputs = pool(inputs, output_states=False)
    keras_output = keras_pool(inputs)

    self.assertAllEqual(outputs.shape, keras_output.shape)
    self.assertAllClose(outputs, keras_output)

  def test_stream_global_average_pool(self):
    """Streaming the input in 1, 2 or 4 chunks matches the one-shot result."""
    gap = nn_layers.GlobalAveragePool3D(keepdims=True, causal=False)

    inputs = tf.range(4, dtype=tf.float32) + 1.
    inputs = tf.reshape(inputs, [1, 4, 1, 1, 1])
    inputs = tf.tile(inputs, [1, 1, 2, 2, 3])
    expected, _ = gap(inputs)

    for num_splits in [1, 2, 4]:
      frames = tf.split(inputs, num_splits, axis=1)
      # Start each run from empty states and thread them through the chunks.
      states = {}
      predicted = None
      for frame in frames:
        predicted, states = gap(frame, states=states)

      self.assertEqual(predicted.shape, expected.shape)
      self.assertAllClose(predicted, expected)
      self.assertAllClose(
          predicted,
          [[[[[2.5, 2.5, 2.5]]]]])

  def test_causal_stream_global_average_pool(self):
    """Causal streaming yields the same per-frame cumulative averages."""
    gap = nn_layers.GlobalAveragePool3D(keepdims=True, causal=True)

    inputs = tf.range(4, dtype=tf.float32) + 1.
    inputs = tf.reshape(inputs, [1, 4, 1, 1, 1])
    inputs = tf.tile(inputs, [1, 1, 2, 2, 3])
    expected, _ = gap(inputs)

    for num_splits in [1, 2, 4]:
      frames = tf.split(inputs, num_splits, axis=1)
      states = {}
      predicted = []
      for frame in frames:
        x, states = gap(frame, states=states)
        predicted.append(x)
      # Reassemble the per-chunk outputs along the frame axis.
      predicted = tf.concat(predicted, axis=1)

      self.assertEqual(predicted.shape, expected.shape)
      self.assertAllClose(predicted, expected)
      self.assertAllClose(
          predicted,
          [[[[[1.0, 1.0, 1.0]]],
            [[[1.5, 1.5, 1.5]]],
            [[[2.0, 2.0, 2.0]]],
            [[[2.5, 2.5, 2.5]]]]])

  def test_spatial_average_pool(self):
    """Spatial pooling averages axes 2 and 3 and keeps the frame axis."""
    pool = nn_layers.SpatialAveragePool3D(keepdims=True)

    inputs = tf.range(64, dtype=tf.float32) + 1.
    inputs = tf.reshape(inputs, [1, 4, 4, 4, 1])

    output = pool(inputs)

    self.assertEqual(output.shape, [1, 4, 1, 1, 1])
    self.assertAllClose(
        output,
        [[[[[8.50]]],
          [[[24.5]]],
          [[[40.5]]],
          [[[56.5]]]]])

  def test_conv2d_causal(self):
    """Buffered (pre-padded) and unbuffered causal Conv2D agree."""
    conv2d = nn_layers.Conv2D(
        filters=3,
        kernel_size=(3, 3),
        strides=(1, 2),
        padding='causal',
        use_buffered_input=True,
        kernel_initializer='ones',
        use_bias=False,
    )

    inputs = tf.ones([1, 4, 2, 3])

    # Pad two leading zeros on axis 1 by hand, as buffered mode expects.
    paddings = [[0, 0], [2, 0], [0, 0], [0, 0]]
    padded_inputs = tf.pad(inputs, paddings)
    predicted = conv2d(padded_inputs)

    expected = tf.constant(
        [[[[6.0, 6.0, 6.0]],
          [[12., 12., 12.]],
          [[18., 18., 18.]],
          [[18., 18., 18.]]]])

    self.assertEqual(predicted.shape, expected.shape)
    self.assertAllClose(predicted, expected)

    # With buffering off the layer pads internally; results must match.
    conv2d.use_buffered_input = False
    predicted = conv2d(inputs)

    self.assertFalse(conv2d.use_buffered_input)
    self.assertEqual(predicted.shape, expected.shape)
    self.assertAllClose(predicted, expected)

  def test_depthwise_conv2d_causal(self):
    """Buffered and unbuffered causal DepthwiseConv2D agree."""
    conv2d = nn_layers.DepthwiseConv2D(
        kernel_size=(3, 3),
        strides=(1, 1),
        padding='causal',
        use_buffered_input=True,
        depthwise_initializer='ones',
        use_bias=False,
    )

    inputs = tf.ones([1, 2, 2, 3])

    paddings = [[0, 0], [2, 0], [0, 0], [0, 0]]
    padded_inputs = tf.pad(inputs, paddings)
    predicted = conv2d(padded_inputs)

    expected = tf.constant(
        [[[[2., 2., 2.],
           [2., 2., 2.]],
          [[4., 4., 4.],
           [4., 4., 4.]]]])

    self.assertEqual(predicted.shape, expected.shape)
    self.assertAllClose(predicted, expected)

    conv2d.use_buffered_input = False
    predicted = conv2d(inputs)

    self.assertEqual(predicted.shape, expected.shape)
    self.assertAllClose(predicted, expected)

  def test_conv3d_causal(self):
    """Buffered and unbuffered causal Conv3D agree."""
    conv3d = nn_layers.Conv3D(
        filters=3,
        kernel_size=(3, 3, 3),
        strides=(1, 2, 2),
        padding='causal',
        use_buffered_input=True,
        kernel_initializer='ones',
        use_bias=False,
    )

    inputs = tf.ones([1, 2, 4, 4, 3])

    paddings = [[0, 0], [2, 0], [0, 0], [0, 0], [0, 0]]
    padded_inputs = tf.pad(inputs, paddings)
    predicted = conv3d(padded_inputs)

    expected = tf.constant(
        [[[[[12., 12., 12.],
            [18., 18., 18.]],
           [[18., 18., 18.],
            [27., 27., 27.]]],
          [[[24., 24., 24.],
            [36., 36., 36.]],
           [[36., 36., 36.],
            [54., 54., 54.]]]]])

    self.assertEqual(predicted.shape, expected.shape)
    self.assertAllClose(predicted, expected)

    conv3d.use_buffered_input = False
    predicted = conv3d(inputs)

    self.assertEqual(predicted.shape, expected.shape)
    self.assertAllClose(predicted, expected)

  def test_depthwise_conv3d_causal(self):
    """Same as the Conv3D test, but with `groups=3`."""
    conv3d = nn_layers.Conv3D(
        filters=3,
        kernel_size=(3, 3, 3),
        strides=(1, 2, 2),
        padding='causal',
        use_buffered_input=True,
        kernel_initializer='ones',
        use_bias=False,
        groups=3,
    )

    inputs = tf.ones([1, 2, 4, 4, 3])

    paddings = [[0, 0], [2, 0], [0, 0], [0, 0], [0, 0]]
    padded_inputs = tf.pad(inputs, paddings)
    predicted = conv3d(padded_inputs)

    expected = tf.constant(
        [[[[[4.0, 4.0, 4.0],
            [6.0, 6.0, 6.0]],
           [[6.0, 6.0, 6.0],
            [9.0, 9.0, 9.0]]],
          [[[8.0, 8.0, 8.0],
            [12., 12., 12.]],
           [[12., 12., 12.],
            [18., 18., 18.]]]]])

    self.assertEqual(predicted.shape, expected.shape)
    self.assertAllClose(predicted, expected)

    conv3d.use_buffered_input = False
    predicted = conv3d(inputs)

    self.assertEqual(predicted.shape, expected.shape)
    self.assertAllClose(predicted, expected)
# Run the tests through the TensorFlow test runner when executed as a script.
if __name__ == '__main__':
  tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment