Commit bb124157 authored by stephenwu's avatar stephenwu
Browse files

Merge branch 'master' of https://github.com/tensorflow/models into RTESuperGLUE

parents 2e9bb539 0edeb7f6
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Mask sampler.""" """Contains definitions of mask sampler."""
# Import libraries # Import libraries
import tensorflow as tf import tensorflow as tf
...@@ -30,34 +30,34 @@ def _sample_and_crop_foreground_masks(candidate_rois, ...@@ -30,34 +30,34 @@ def _sample_and_crop_foreground_masks(candidate_rois,
"""Samples and creates cropped foreground masks for training. """Samples and creates cropped foreground masks for training.
Args: Args:
candidate_rois: a tensor of shape of [batch_size, N, 4], where N is the candidate_rois: A `tf.Tensor` of shape of [batch_size, N, 4], where N is the
number of candidate RoIs to be considered for mask sampling. It includes number of candidate RoIs to be considered for mask sampling. It includes
both positive and negative RoIs. The `num_mask_samples_per_image` positive both positive and negative RoIs. The `num_mask_samples_per_image` positive
RoIs will be sampled to create mask training targets. RoIs will be sampled to create mask training targets.
candidate_gt_boxes: a tensor of shape of [batch_size, N, 4], storing the candidate_gt_boxes: A `tf.Tensor` of shape of [batch_size, N, 4], storing
corresponding groundtruth boxes to the `candidate_rois`. the corresponding groundtruth boxes to the `candidate_rois`.
candidate_gt_classes: a tensor of shape of [batch_size, N], storing the candidate_gt_classes: A `tf.Tensor` of shape of [batch_size, N], storing the
corresponding groundtruth classes to the `candidate_rois`. 0 in the tensor corresponding groundtruth classes to the `candidate_rois`. 0 in the tensor
corresponds to the background class, i.e. negative RoIs. corresponds to the background class, i.e. negative RoIs.
candidate_gt_indices: a tensor of shape [batch_size, N], storing the candidate_gt_indices: A `tf.Tensor` of shape [batch_size, N], storing the
corresponding groundtruth instance indices to the `candidate_gt_boxes`, corresponding groundtruth instance indices to the `candidate_gt_boxes`,
i.e. gt_boxes[candidate_gt_indices[:, i]] = candidate_gt_boxes[:, i] and i.e. gt_boxes[candidate_gt_indices[:, i]] = candidate_gt_boxes[:, i] and
gt_boxes which is of shape [batch_size, MAX_INSTANCES, 4], M >= N, is the gt_boxes which is of shape [batch_size, MAX_INSTANCES, 4], M >= N, is
superset of candidate_gt_boxes. the superset of candidate_gt_boxes.
gt_masks: a tensor of [batch_size, MAX_INSTANCES, mask_height, mask_width] gt_masks: A `tf.Tensor` of [batch_size, MAX_INSTANCES, mask_height,
containing all the groundtruth masks which sample masks are drawn from. mask_width] containing all the groundtruth masks which sample masks are
num_sampled_masks: an integer which specifies the number of masks drawn from.
to sample. num_sampled_masks: An `int` that specifies the number of masks to sample.
mask_target_size: an integer which specifies the final cropped mask size mask_target_size: An `int` that specifies the final cropped mask size after
after sampling. The output masks are resized w.r.t the sampled RoIs. sampling. The output masks are resized w.r.t the sampled RoIs.
Returns: Returns:
foreground_rois: a tensor of shape of [batch_size, K, 4] storing the RoI foreground_rois: A `tf.Tensor` of shape of [batch_size, K, 4] storing the
that corresponds to the sampled foreground masks, where RoI that corresponds to the sampled foreground masks, where
K = num_mask_samples_per_image. K = num_mask_samples_per_image.
foreground_classes: a tensor of shape of [batch_size, K] storing the classes foreground_classes: A `tf.Tensor` of shape of [batch_size, K] storing the
corresponding to the sampled foreground masks. classes corresponding to the sampled foreground masks.
cropped_foreground_masks: a tensor of shape of cropped_foreground_masks: A `tf.Tensor` of shape of
[batch_size, K, mask_target_size, mask_target_size] storing the cropped [batch_size, K, mask_target_size, mask_target_size] storing the cropped
foreground masks used for training. foreground masks used for training.
""" """
...@@ -120,34 +120,36 @@ class MaskSampler(tf.keras.layers.Layer): ...@@ -120,34 +120,36 @@ class MaskSampler(tf.keras.layers.Layer):
candidate_gt_classes, candidate_gt_classes,
candidate_gt_indices, candidate_gt_indices,
gt_masks): gt_masks):
"""Sample and create mask targets for training. """Samples and creates mask targets for training.
Args: Args:
candidate_rois: a tensor of shape of [batch_size, N, 4], where N is the candidate_rois: A `tf.Tensor` of shape of [batch_size, N, 4], where N is
number of candidate RoIs to be considered for mask sampling. It includes the number of candidate RoIs to be considered for mask sampling. It
both positive and negative RoIs. The `num_mask_samples_per_image` includes both positive and negative RoIs. The
positive RoIs will be sampled to create mask training targets. `num_mask_samples_per_image` positive RoIs will be sampled to create
candidate_gt_boxes: a tensor of shape of [batch_size, N, 4], storing the mask training targets.
corresponding groundtruth boxes to the `candidate_rois`. candidate_gt_boxes: A `tf.Tensor` of shape of [batch_size, N, 4], storing
candidate_gt_classes: a tensor of shape of [batch_size, N], storing the the corresponding groundtruth boxes to the `candidate_rois`.
corresponding groundtruth classes to the `candidate_rois`. 0 in the candidate_gt_classes: A `tf.Tensor` of shape of [batch_size, N], storing
the corresponding groundtruth classes to the `candidate_rois`. 0 in the
tensor corresponds to the background class, i.e. negative RoIs. tensor corresponds to the background class, i.e. negative RoIs.
candidate_gt_indices: a tensor of shape [batch_size, N], storing the candidate_gt_indices: A `tf.Tensor` of shape [batch_size, N], storing the
corresponding groundtruth instance indices to the `candidate_gt_boxes`, corresponding groundtruth instance indices to the `candidate_gt_boxes`,
i.e. gt_boxes[candidate_gt_indices[:, i]] = candidate_gt_boxes[:, i], i.e. gt_boxes[candidate_gt_indices[:, i]] = candidate_gt_boxes[:, i],
where gt_boxes which is of shape [batch_size, MAX_INSTANCES, 4], M >= N, where gt_boxes which is of shape [batch_size, MAX_INSTANCES, 4], M >=
is the superset of candidate_gt_boxes. N, is the superset of candidate_gt_boxes.
gt_masks: a tensor of [batch_size, MAX_INSTANCES, mask_height, mask_width] gt_masks: A `tf.Tensor` of [batch_size, MAX_INSTANCES, mask_height,
containing all the groundtruth masks which sample masks are drawn from. mask_width] containing all the groundtruth masks which sample masks are
after sampling. The output masks are resized w.r.t the sampled RoIs. drawn from. After sampling, the output masks are resized w.r.t the
sampled RoIs.
Returns: Returns:
foreground_rois: a tensor of shape of [batch_size, K, 4] storing the RoI foreground_rois: A `tf.Tensor` of shape of [batch_size, K, 4] storing the
that corresponds to the sampled foreground masks, where RoI that corresponds to the sampled foreground masks, where
K = num_mask_samples_per_image. K = num_mask_samples_per_image.
foreground_classes: a tensor of shape of [batch_size, K] storing the foreground_classes: A `tf.Tensor` of shape of [batch_size, K] storing the
classes corresponding to the sampled foreground masks. classes corresponding to the sampled foreground masks.
cropped_foreground_masks: a tensor of shape of cropped_foreground_masks: A `tf.Tensor` of shape of
[batch_size, K, mask_target_size, mask_target_size] storing the [batch_size, K, mask_target_size, mask_target_size] storing the
cropped foreground masks used for training. cropped foreground masks used for training.
""" """
......
...@@ -73,33 +73,33 @@ class ResidualBlock(tf.keras.layers.Layer): ...@@ -73,33 +73,33 @@ class ResidualBlock(tf.keras.layers.Layer):
norm_momentum=0.99, norm_momentum=0.99,
norm_epsilon=0.001, norm_epsilon=0.001,
**kwargs): **kwargs):
"""A residual block with BN after convolutions. """Initializes a residual block with BN after convolutions.
Args: Args:
filters: `int` number of filters for the first two convolutions. Note that filters: An `int` number of filters for the first two convolutions. Note
the third and final convolution will use 4 times as many filters. that the third and final convolution will use 4 times as many filters.
strides: `int` block stride. If greater than 1, this block will ultimately strides: An `int` block stride. If greater than 1, this block will
downsample the input. ultimately downsample the input.
use_projection: `bool` for whether this block should use a projection use_projection: A `bool` for whether this block should use a projection
shortcut (versus the default identity shortcut). This is usually `True` shortcut (versus the default identity shortcut). This is usually `True`
for the first block of a block group, which may change the number of for the first block of a block group, which may change the number of
filters and the resolution. filters and the resolution.
se_ratio: `float` or None. Ratio of the Squeeze-and-Excitation layer. se_ratio: A `float` or None. Ratio of the Squeeze-and-Excitation layer.
resnetd_shortcut: `bool` if True, apply the resnetd style modification to resnetd_shortcut: A `bool` if True, apply the resnetd style modification
the shortcut connection. Not implemented in residual blocks. to the shortcut connection. Not implemented in residual blocks.
stochastic_depth_drop_rate: `float` or None. if not None, drop rate for stochastic_depth_drop_rate: A `float` or None. if not None, drop rate for
the stochastic depth layer. the stochastic depth layer.
kernel_initializer: kernel_initializer for convolutional layers. kernel_initializer: A `str` of kernel_initializer for convolutional
kernel_regularizer: tf.keras.regularizers.Regularizer object for Conv2D. layers.
Default to None. kernel_regularizer: A `tf.keras.regularizers.Regularizer` object for
bias_regularizer: tf.keras.regularizers.Regularizer object for Conv2d. Conv2D. Default to None.
Default to None. bias_regularizer: A `tf.keras.regularizers.Regularizer` object for Conv2d.
activation: `str` name of the activation function. Default to None.
use_sync_bn: if True, use synchronized batch normalization. activation: A `str` name of the activation function.
norm_momentum: `float` normalization omentum for the moving average. use_sync_bn: A `bool`. If True, use synchronized batch normalization.
norm_epsilon: `float` small float added to variance to avoid dividing by norm_momentum: A `float` of normalization momentum for the moving average.
zero. norm_epsilon: A `float` added to variance to avoid dividing by zero.
**kwargs: keyword arguments to be passed. **kwargs: Additional keyword arguments to be passed.
""" """
super(ResidualBlock, self).__init__(**kwargs) super(ResidualBlock, self).__init__(**kwargs)
...@@ -250,34 +250,34 @@ class BottleneckBlock(tf.keras.layers.Layer): ...@@ -250,34 +250,34 @@ class BottleneckBlock(tf.keras.layers.Layer):
norm_momentum=0.99, norm_momentum=0.99,
norm_epsilon=0.001, norm_epsilon=0.001,
**kwargs): **kwargs):
"""A standard bottleneck block with BN after convolutions. """Initializes a standard bottleneck block with BN after convolutions.
Args: Args:
filters: `int` number of filters for the first two convolutions. Note that filters: An `int` number of filters for the first two convolutions. Note
the third and final convolution will use 4 times as many filters. that the third and final convolution will use 4 times as many filters.
strides: `int` block stride. If greater than 1, this block will ultimately strides: An `int` block stride. If greater than 1, this block will
downsample the input. ultimately downsample the input.
dilation_rate: `int` dilation_rate of convolutions. Default to 1. dilation_rate: An `int` dilation_rate of convolutions. Default to 1.
use_projection: `bool` for whether this block should use a projection use_projection: A `bool` for whether this block should use a projection
shortcut (versus the default identity shortcut). This is usually `True` shortcut (versus the default identity shortcut). This is usually `True`
for the first block of a block group, which may change the number of for the first block of a block group, which may change the number of
filters and the resolution. filters and the resolution.
se_ratio: `float` or None. Ratio of the Squeeze-and-Excitation layer. se_ratio: A `float` or None. Ratio of the Squeeze-and-Excitation layer.
resnetd_shortcut: `bool` if True, apply the resnetd style modification to resnetd_shortcut: A `bool`. If True, apply the resnetd style modification
the shortcut connection. to the shortcut connection.
stochastic_depth_drop_rate: `float` or None. if not None, drop rate for stochastic_depth_drop_rate: A `float` or None. If not None, drop rate for
the stochastic depth layer. the stochastic depth layer.
kernel_initializer: kernel_initializer for convolutional layers. kernel_initializer: A `str` of kernel_initializer for convolutional
kernel_regularizer: tf.keras.regularizers.Regularizer object for Conv2D. layers.
Default to None. kernel_regularizer: A `tf.keras.regularizers.Regularizer` object for
bias_regularizer: tf.keras.regularizers.Regularizer object for Conv2d. Conv2D. Default to None.
Default to None. bias_regularizer: A `tf.keras.regularizers.Regularizer` object for Conv2d.
activation: `str` name of the activation function. Default to None.
use_sync_bn: if True, use synchronized batch normalization. activation: A `str` name of the activation function.
norm_momentum: `float` normalization omentum for the moving average. use_sync_bn: A `bool`. If True, use synchronized batch normalization.
norm_epsilon: `float` small float added to variance to avoid dividing by norm_momentum: A `float` of normalization momentum for the moving average.
zero. norm_epsilon: A `float` added to variance to avoid dividing by zero.
**kwargs: keyword arguments to be passed. **kwargs: Additional keyword arguments to be passed.
""" """
super(BottleneckBlock, self).__init__(**kwargs) super(BottleneckBlock, self).__init__(**kwargs)
...@@ -472,47 +472,48 @@ class InvertedBottleneckBlock(tf.keras.layers.Layer): ...@@ -472,47 +472,48 @@ class InvertedBottleneckBlock(tf.keras.layers.Layer):
norm_momentum=0.99, norm_momentum=0.99,
norm_epsilon=0.001, norm_epsilon=0.001,
**kwargs): **kwargs):
"""An inverted bottleneck block with BN after convolutions. """Initializes an inverted bottleneck block with BN after convolutions.
Args: Args:
in_filters: `int` number of filters of the input tensor. in_filters: An `int` number of filters of the input tensor.
out_filters: `int` number of filters of the output tensor. out_filters: An `int` number of filters of the output tensor.
expand_ratio: `int` expand_ratio for an inverted bottleneck block. expand_ratio: An `int` of expand_ratio for an inverted bottleneck block.
strides: `int` block stride. If greater than 1, this block will ultimately strides: An `int` block stride. If greater than 1, this block will
downsample the input. ultimately downsample the input.
kernel_size: `int` kernel_size of the depthwise conv layer. kernel_size: An `int` kernel_size of the depthwise conv layer.
se_ratio: `float` or None. If not None, se ratio for the squeeze and se_ratio: A `float` or None. If not None, se ratio for the squeeze and
excitation layer. excitation layer.
stochastic_depth_drop_rate: `float` or None. if not None, drop rate for stochastic_depth_drop_rate: A `float` or None. if not None, drop rate for
the stochastic depth layer. the stochastic depth layer.
kernel_initializer: kernel_initializer for convolutional layers. kernel_initializer: A `str` of kernel_initializer for convolutional
kernel_regularizer: tf.keras.regularizers.Regularizer object for Conv2D. layers.
Default to None. kernel_regularizer: A `tf.keras.regularizers.Regularizer` object for
bias_regularizer: tf.keras.regularizers.Regularizer object for Conv2d. Conv2D. Default to None.
bias_regularizer: A `tf.keras.regularizers.Regularizer` object for Conv2d.
Default to None. Default to None.
activation: `str` name of the activation function. activation: A `str` name of the activation function.
se_inner_activation: Squeeze excitation inner activation. se_inner_activation: A `str` name of squeeze-excitation inner activation.
se_gating_activation: Squeeze excitation gating activation. se_gating_activation: A `str` name of squeeze-excitation gating
expand_se_in_filters: Whether or not to expand in_filter in squeeze and activation.
excitation layer. expand_se_in_filters: A `bool` of whether or not to expand in_filter in
depthwise_activation: `str` name of the activation function for depthwise squeeze and excitation layer.
only. depthwise_activation: A `str` name of the activation function for
use_sync_bn: if True, use synchronized batch normalization. depthwise only.
dilation_rate: `int` an integer specifying the dilation rate to use for. use_sync_bn: A `bool`. If True, use synchronized batch normalization.
divisible_by: `int` ensures all inner dimensions are divisible by this dilation_rate: An `int` that specifies the dilation rate to use for.
number. divisible_by: An `int` that ensures all inner dimensions are divisible by
dilated convolution. Can be a single integer to specify the same value for this number.
all spatial dimensions. dilated convolution: An `int` to specify the same value for all spatial
regularize_depthwise: `bool` whether or not apply regularization on dimensions.
regularize_depthwise: A `bool` of whether or not apply regularization on
depthwise. depthwise.
use_depthwise: `bool` whether to use fused convolutions instead of use_depthwise: A `bool` of whether to use fused convolutions instead of
depthwise. depthwise.
use_residual: `bool`whether to include residual connection between input use_residual: A `bool` of whether to include residual connection between
and output. input and output.
norm_momentum: `float` normalization omentum for the moving average. norm_momentum: A `float` of normalization momentum for the moving average.
norm_epsilon: `float` small float added to variance to avoid dividing by norm_epsilon: A `float` added to variance to avoid dividing by zero.
zero. **kwargs: Additional keyword arguments to be passed.
**kwargs: keyword arguments to be passed.
""" """
super(InvertedBottleneckBlock, self).__init__(**kwargs) super(InvertedBottleneckBlock, self).__init__(**kwargs)
...@@ -702,10 +703,12 @@ class InvertedBottleneckBlock(tf.keras.layers.Layer): ...@@ -702,10 +703,12 @@ class InvertedBottleneckBlock(tf.keras.layers.Layer):
@tf.keras.utils.register_keras_serializable(package='Vision') @tf.keras.utils.register_keras_serializable(package='Vision')
class ResidualInner(tf.keras.layers.Layer): class ResidualInner(tf.keras.layers.Layer):
"""Single inner block of a residual. """Creates a single inner block of a residual.
This corresponds to `F`/`G` functions in the RevNet paper: This corresponds to `F`/`G` functions in the RevNet paper:
https://arxiv.org/pdf/1707.04585.pdf Aidan N. Gomez, Mengye Ren, Raquel Urtasun, Roger B. Grosse.
The Reversible Residual Network: Backpropagation Without Storing Activations.
(https://arxiv.org/pdf/1707.04585.pdf)
""" """
def __init__( def __init__(
...@@ -721,22 +724,21 @@ class ResidualInner(tf.keras.layers.Layer): ...@@ -721,22 +724,21 @@ class ResidualInner(tf.keras.layers.Layer):
norm_epsilon: float = 0.001, norm_epsilon: float = 0.001,
batch_norm_first: bool = True, batch_norm_first: bool = True,
**kwargs): **kwargs):
"""ResidualInner Initialization. """Initializes a ResidualInner.
Args: Args:
filters: `int` output filter size. filters: An `int` of output filter size.
strides: `int` stride size for convolution for the residual block. strides: An `int` of stride size for convolution for the residual block.
kernel_initializer: `str` or `tf.keras.initializers.Initializer` instance kernel_initializer: A `str` or `tf.keras.initializers.Initializer`
for convolutional layers. instance for convolutional layers.
kernel_regularizer: `tf.keras.regularizers.Regularizer` for Conv2D. kernel_regularizer: A `tf.keras.regularizers.Regularizer` for Conv2D.
activation: `str` or `callable` instance of the activation function. activation: A `str` or `callable` instance of the activation function.
use_sync_bn: `bool` if True, use synchronized batch normalization. use_sync_bn: A `bool`. If True, use synchronized batch normalization.
norm_momentum: `float` normalization omentum for the moving average. norm_momentum: A `float` of normalization momentum for the moving average.
norm_epsilon: `float` small float added to variance to avoid dividing by norm_epsilon: A `float` added to variance to avoid dividing by zero.
zero. batch_norm_first: A `bool` of whether to apply activation and batch norm
batch_norm_first: `bool` whether to apply activation and batch norm
before conv. before conv.
**kwargs: additional keyword arguments to be passed. **kwargs: Additional keyword arguments to be passed.
""" """
super(ResidualInner, self).__init__(**kwargs) super(ResidualInner, self).__init__(**kwargs)
...@@ -824,10 +826,12 @@ class ResidualInner(tf.keras.layers.Layer): ...@@ -824,10 +826,12 @@ class ResidualInner(tf.keras.layers.Layer):
@tf.keras.utils.register_keras_serializable(package='Vision') @tf.keras.utils.register_keras_serializable(package='Vision')
class BottleneckResidualInner(tf.keras.layers.Layer): class BottleneckResidualInner(tf.keras.layers.Layer):
"""Single inner block of a bottleneck residual. """Creates a single inner block of a bottleneck.
This corresponds to `F`/`G` functions in the RevNet paper: This corresponds to `F`/`G` functions in the RevNet paper:
https://arxiv.org/pdf/1707.04585.pdf Aidan N. Gomez, Mengye Ren, Raquel Urtasun, Roger B. Grosse.
The Reversible Residual Network: Backpropagation Without Storing Activations.
(https://arxiv.org/pdf/1707.04585.pdf)
""" """
def __init__( def __init__(
...@@ -843,24 +847,23 @@ class BottleneckResidualInner(tf.keras.layers.Layer): ...@@ -843,24 +847,23 @@ class BottleneckResidualInner(tf.keras.layers.Layer):
norm_epsilon: float = 0.001, norm_epsilon: float = 0.001,
batch_norm_first: bool = True, batch_norm_first: bool = True,
**kwargs): **kwargs):
"""BottleneckResidualInner Initialization. """Initializes a BottleneckResidualInner.
Args: Args:
filters: `int` number of filters for first 2 convolutions. Last filters: An `int` number of filters for first 2 convolutions. Last,
Last, and thus the number of output channels from the bottleneck and thus the number of output channels from the bottleneck block is
block is `4*filters` `4*filters`
strides: `int` stride size for convolution for the residual block. strides: An `int` of stride size for convolution for the residual block.
kernel_initializer: `str` or `tf.keras.initializers.Initializer` instance kernel_initializer: A `str` or `tf.keras.initializers.Initializer`
for convolutional layers. instance for convolutional layers.
kernel_regularizer: `tf.keras.regularizers.Regularizer` for Conv2D. kernel_regularizer: A `tf.keras.regularizers.Regularizer` for Conv2D.
activation: `str` or `callable` instance of the activation function. activation: A `str` or `callable` instance of the activation function.
use_sync_bn: `bool` if True, use synchronized batch normalization. use_sync_bn: A `bool`. If True, use synchronized batch normalization.
norm_momentum: `float` normalization omentum for the moving average. norm_momentum: A `float` of normalization momentum for the moving average.
norm_epsilon: `float` small float added to variance to avoid dividing by norm_epsilon: A `float` added to variance to avoid dividing by zero.
zero. batch_norm_first: A `bool` of whether to apply activation and batch norm
batch_norm_first: `bool` whether to apply activation and batch norm
before conv. before conv.
**kwargs: additional keyword arguments to be passed. **kwargs: Additional keyword arguments to be passed.
""" """
super(BottleneckResidualInner, self).__init__(**kwargs) super(BottleneckResidualInner, self).__init__(**kwargs)
...@@ -962,7 +965,7 @@ class BottleneckResidualInner(tf.keras.layers.Layer): ...@@ -962,7 +965,7 @@ class BottleneckResidualInner(tf.keras.layers.Layer):
@tf.keras.utils.register_keras_serializable(package='Vision') @tf.keras.utils.register_keras_serializable(package='Vision')
class ReversibleLayer(tf.keras.layers.Layer): class ReversibleLayer(tf.keras.layers.Layer):
"""A reversible layer. """Creates a reversible layer.
Computes y1 = x1 + f(x2), y2 = x2 + g(y1), where f and g can be arbitrary Computes y1 = x1 + f(x2), y2 = x2 + g(y1), where f and g can be arbitrary
layers that are stateless, which in this case are `ResidualInner` layers. layers that are stateless, which in this case are `ResidualInner` layers.
...@@ -973,20 +976,21 @@ class ReversibleLayer(tf.keras.layers.Layer): ...@@ -973,20 +976,21 @@ class ReversibleLayer(tf.keras.layers.Layer):
g: tf.keras.layers.Layer, g: tf.keras.layers.Layer,
manual_grads: bool = True, manual_grads: bool = True,
**kwargs): **kwargs):
"""ReversibleLayer Initialization. """Initializes a ReversibleLayer.
Args: Args:
f: `tf.keras.layers.Layer` f inner block referred to in paper. Each f: A `tf.keras.layers.Layer` instance of `f` inner block referred to in
reversible layer consists of two inner functions. For example, in RevNet paper. Each reversible layer consists of two inner functions. For
the reversible residual consists of two f/g inner (bottleneck) residual example, in RevNet the reversible residual consists of two f/g inner
functions. Where the input to the reversible layer is x, the input gets (bottleneck) residual functions. Where the input to the reversible layer
partitioned in the channel dimension and the forward pass follows (eq8): is x, the input gets partitioned in the channel dimension and the
x = [x1; x2], z1 = x1 + f(x2), y2 = x2 + g(z1), y1 = stop_gradient(z1). forward pass follows (eq8): x = [x1; x2], z1 = x1 + f(x2), y2 = x2 +
g: `tf.keras.layers.Layer` g inner block referred to in paper. Detailed g(z1), y1 = stop_gradient(z1).
explanation same as above as `f` arg. g: A `tf.keras.layers.Layer` instance of `g` inner block referred to in
manual_grads: `bool` [Testing Only] whether to manually take gradients paper. Detailed explanation same as above as `f` arg.
as in Algorithm 1 or defer to autograd. manual_grads: A `bool` [Testing Only] of whether to manually take
**kwargs: additional keyword arguments to be passed. gradients as in Algorithm 1 or defer to autograd.
**kwargs: Additional keyword arguments to be passed.
""" """
super(ReversibleLayer, self).__init__(**kwargs) super(ReversibleLayer, self).__init__(**kwargs)
...@@ -1030,16 +1034,19 @@ class ReversibleLayer(tf.keras.layers.Layer): ...@@ -1030,16 +1034,19 @@ class ReversibleLayer(tf.keras.layers.Layer):
x: tf.Tensor x: tf.Tensor
) -> Tuple[tf.Tensor, Callable[[Any], Tuple[List[tf.Tensor], ) -> Tuple[tf.Tensor, Callable[[Any], Tuple[List[tf.Tensor],
List[tf.Tensor]]]]: List[tf.Tensor]]]]:
"""Implements Algorithm 1 in RevNet paper. """Implements Algorithm 1 in the RevNet paper.
Paper: https://arxiv.org/pdf/1707.04585.pdf Aidan N. Gomez, Mengye Ren, Raquel Urtasun, Roger B. Grosse.
The Reversible Residual Network: Backpropagation Without Storing
Activations.
(https://arxiv.org/pdf/1707.04585.pdf)
Args: Args:
x: input tensor. x: An input `tf.Tensor`.
Returns: Returns:
y: the output [y1; y2] in algorithm 1. y: The output [y1; y2] in Algorithm 1.
grad_fn: callable function that computes the gradients. grad_fn: A callable function that computes the gradients.
""" """
with tf.GradientTape() as fwdtape: with tf.GradientTape() as fwdtape:
fwdtape.watch(x) fwdtape.watch(x)
...@@ -1135,7 +1142,7 @@ class ReversibleLayer(tf.keras.layers.Layer): ...@@ -1135,7 +1142,7 @@ class ReversibleLayer(tf.keras.layers.Layer):
@tf.keras.utils.register_keras_serializable(package='Vision') @tf.keras.utils.register_keras_serializable(package='Vision')
class DepthwiseSeparableConvBlock(tf.keras.layers.Layer): class DepthwiseSeparableConvBlock(tf.keras.layers.Layer):
"""A depthwise separable convolution block with batch normalization.""" """Creates a depthwise separable convolution block with batch normalization."""
def __init__( def __init__(
self, self,
...@@ -1151,29 +1158,29 @@ class DepthwiseSeparableConvBlock(tf.keras.layers.Layer): ...@@ -1151,29 +1158,29 @@ class DepthwiseSeparableConvBlock(tf.keras.layers.Layer):
norm_momentum: float = 0.99, norm_momentum: float = 0.99,
norm_epsilon: float = 0.001, norm_epsilon: float = 0.001,
**kwargs): **kwargs):
"""An convolution block with batch normalization. """Initializes a convolution block with batch normalization.
Args: Args:
filters: `int` number of filters for the first two convolutions. Note that filters: An `int` number of filters for the first two convolutions. Note
the third and final convolution will use 4 times as many filters. that the third and final convolution will use 4 times as many filters.
kernel_size: `int` an integer specifying the height and width of the kernel_size: An `int` that specifies the height and width of the 2D
2D convolution window. convolution window.
strides: `int` block stride. If greater than 1, this block will ultimately strides: An `int` of block stride. If greater than 1, this block will
downsample the input. ultimately downsample the input.
regularize_depthwise: if True, apply regularization on depthwise. regularize_depthwise: A `bool`. If True, apply regularization on
activation: `str` name of the activation function. depthwise.
kernel_initializer: kernel_initializer for convolutional layers. activation: A `str` name of the activation function.
kernel_regularizer: tf.keras.regularizers.Regularizer object for Conv2D. kernel_initializer: A `str` of kernel_initializer for convolutional
Default to None. layers.
dilation_rate: an integer or tuple/list of 2 integers, specifying kernel_regularizer: A `tf.keras.regularizers.Regularizer` object for
the dilation rate to use for dilated convolution. Conv2D. Default to None.
Can be a single integer to specify the same value for dilation_rate: An `int` or tuple/list of 2 `int`, specifying the dilation
all spatial dimensions. rate to use for dilated convolution. Can be a single integer to specify
use_sync_bn: if True, use synchronized batch normalization. the same value for all spatial dimensions.
norm_momentum: `float` normalization omentum for the moving average. use_sync_bn: A `bool`. If True, use synchronized batch normalization.
norm_epsilon: `float` small float added to variance to avoid dividing by norm_momentum: A `float` of normalization momentum for the moving average.
zero. norm_epsilon: A `float` added to variance to avoid dividing by zero.
**kwargs: keyword arguments to be passed. **kwargs: Additional keyword arguments to be passed.
""" """
super(DepthwiseSeparableConvBlock, self).__init__(**kwargs) super(DepthwiseSeparableConvBlock, self).__init__(**kwargs)
self._filters = filters self._filters = filters
......
...@@ -21,14 +21,21 @@ from official.modeling import tf_utils ...@@ -21,14 +21,21 @@ from official.modeling import tf_utils
@tf.keras.utils.register_keras_serializable(package='Vision') @tf.keras.utils.register_keras_serializable(package='Vision')
class SelfGating(tf.keras.layers.Layer): class SelfGating(tf.keras.layers.Layer):
"""Feature gating as used in S3D-G (https://arxiv.org/pdf/1712.04851.pdf).""" """Feature gating as used in S3D-G.
This implements the S3D-G network from:
Saining Xie, Chen Sun, Jonathan Huang, Zhuowen Tu, Kevin Murphy.
Rethinking Spatiotemporal Feature Learning: Speed-Accuracy Trade-offs in Video
Classification.
(https://arxiv.org/pdf/1712.04851.pdf)
"""
def __init__(self, filters, **kwargs): def __init__(self, filters, **kwargs):
"""Constructor. """Initializes a self-gating layer.
Args: Args:
filters: `int` number of filters for the convolutional layer. filters: An `int` number of filters for the convolutional layer.
**kwargs: keyword arguments to be passed. **kwargs: Additional keyword arguments to be passed.
""" """
super(SelfGating, self).__init__(**kwargs) super(SelfGating, self).__init__(**kwargs)
self._filters = filters self._filters = filters
...@@ -61,7 +68,7 @@ class SelfGating(tf.keras.layers.Layer): ...@@ -61,7 +68,7 @@ class SelfGating(tf.keras.layers.Layer):
@tf.keras.utils.register_keras_serializable(package='Vision') @tf.keras.utils.register_keras_serializable(package='Vision')
class BottleneckBlock3D(tf.keras.layers.Layer): class BottleneckBlock3D(tf.keras.layers.Layer):
"""A 3D bottleneck block.""" """Creates a 3D bottleneck block."""
def __init__(self, def __init__(self,
filters, filters,
...@@ -77,28 +84,29 @@ class BottleneckBlock3D(tf.keras.layers.Layer): ...@@ -77,28 +84,29 @@ class BottleneckBlock3D(tf.keras.layers.Layer):
norm_momentum=0.99, norm_momentum=0.99,
norm_epsilon=0.001, norm_epsilon=0.001,
**kwargs): **kwargs):
"""A 3D bottleneck block with BN after convolutions. """Initializes a 3D bottleneck block with BN after convolutions.
Args: Args:
filters: `int` number of filters for the first two convolutions. Note that filters: An `int` number of filters for the first two convolutions. Note
the third and final convolution will use 4 times as many filters. that the third and final convolution will use 4 times as many filters.
temporal_kernel_size: `int` kernel size for the temporal convolutional temporal_kernel_size: An `int` of kernel size for the temporal
layer. convolutional layer.
temporal_strides: `int` temporal stride for the temporal temporal_strides: An `int` of temporal stride for the temporal
convolutional layer.
spatial_strides: An `int` of spatial stride for the spatial convolutional
layer. layer.
spatial_strides: `int` spatial stride for the spatial convolutional layer. use_self_gating: A `bool` of whether to apply self-gating module or not.
use_self_gating: `bool` apply self-gating module or not. kernel_initializer: A `str` of kernel_initializer for convolutional
kernel_initializer: kernel_initializer for convolutional layers. layers.
kernel_regularizer: tf.keras.regularizers.Regularizer object for Conv2D. kernel_regularizer: A `tf.keras.regularizers.Regularizer` object for
Default to None. Conv2D. Default to None.
bias_regularizer: tf.keras.regularizers.Regularizer object for Conv2d. bias_regularizer: A `tf.keras.regularizers.Regularizer` object for Conv2d.
Default to None. Default to None.
activation: `str` name of the activation function. activation: A `str` name of the activation function.
use_sync_bn: if True, use synchronized batch normalization. use_sync_bn: A `bool`. If True, use synchronized batch normalization.
norm_momentum: `float` normalization omentum for the moving average. norm_momentum: A `float` of normalization momentum for the moving average.
norm_epsilon: `float` small float added to variance to avoid dividing by norm_epsilon: A `float` added to variance to avoid dividing by zero.
zero. **kwargs: Additional keyword arguments to be passed.
**kwargs: keyword arguments to be passed.
""" """
super(BottleneckBlock3D, self).__init__(**kwargs) super(BottleneckBlock3D, self).__init__(**kwargs)
......
...@@ -14,9 +14,7 @@ ...@@ -14,9 +14,7 @@
# ============================================================================== # ==============================================================================
"""Contains common building blocks for neural networks.""" """Contains common building blocks for neural networks."""
from typing import Optional from typing import Callable, Dict, List, Optional, Tuple, Union
# Import libraries
from absl import logging from absl import logging
import tensorflow as tf import tensorflow as tf
...@@ -24,6 +22,11 @@ import tensorflow as tf ...@@ -24,6 +22,11 @@ import tensorflow as tf
from official.modeling import tf_utils from official.modeling import tf_utils
# Type annotations.
States = Dict[str, tf.Tensor]
Activation = Union[str, Callable]
def make_divisible(value: float, def make_divisible(value: float,
divisor: int, divisor: int,
min_value: Optional[float] = None min_value: Optional[float] = None
...@@ -31,12 +34,12 @@ def make_divisible(value: float, ...@@ -31,12 +34,12 @@ def make_divisible(value: float,
"""This is to ensure that all layers have channels that are divisible by 8. """This is to ensure that all layers have channels that are divisible by 8.
Args: Args:
value: `float` original value. value: A `float` of original value.
divisor: `int` the divisor that needs to be checked upon. divisor: An `int` of the divisor that needs to be checked upon.
min_value: `float` minimum value threshold. min_value: A `float` of minimum value threshold.
Returns: Returns:
The adjusted value in `int` that divisible against divisor. The adjusted value in `int` that is divisible against divisor.
""" """
if min_value is None: if min_value is None:
min_value = divisor min_value = divisor
...@@ -52,7 +55,7 @@ def round_filters(filters: int, ...@@ -52,7 +55,7 @@ def round_filters(filters: int,
divisor: int = 8, divisor: int = 8,
min_depth: Optional[int] = None, min_depth: Optional[int] = None,
skip: bool = False): skip: bool = False):
"""Round number of filters based on width multiplier.""" """Rounds number of filters based on width multiplier."""
orig_f = filters orig_f = filters
if skip or not multiplier: if skip or not multiplier:
return filters return filters
...@@ -67,7 +70,7 @@ def round_filters(filters: int, ...@@ -67,7 +70,7 @@ def round_filters(filters: int,
@tf.keras.utils.register_keras_serializable(package='Vision') @tf.keras.utils.register_keras_serializable(package='Vision')
class SqueezeExcitation(tf.keras.layers.Layer): class SqueezeExcitation(tf.keras.layers.Layer):
"""Squeeze and excitation layer.""" """Creates a squeeze and excitation layer."""
def __init__(self, def __init__(self,
in_filters, in_filters,
...@@ -81,25 +84,26 @@ class SqueezeExcitation(tf.keras.layers.Layer): ...@@ -81,25 +84,26 @@ class SqueezeExcitation(tf.keras.layers.Layer):
activation='relu', activation='relu',
gating_activation='sigmoid', gating_activation='sigmoid',
**kwargs): **kwargs):
"""Implementation for squeeze and excitation. """Initializes a squeeze and excitation layer.
Args: Args:
in_filters: `int` number of filters of the input tensor. in_filters: An `int` number of filters of the input tensor.
out_filters: `int` number of filters of the output tensor. out_filters: An `int` number of filters of the output tensor.
se_ratio: `float` or None. If not None, se ratio for the squeeze and se_ratio: A `float` or None. If not None, se ratio for the squeeze and
excitation layer. excitation layer.
divisible_by: `int` ensures all inner dimensions are divisible by this divisible_by: An `int` that ensures all inner dimensions are divisible by
number. this number.
use_3d_input: `bool` 2D image or 3D input type. use_3d_input: A `bool` of whether input is 2D or 3D image.
kernel_initializer: kernel_initializer for convolutional layers. kernel_initializer: A `str` of kernel_initializer for convolutional
kernel_regularizer: tf.keras.regularizers.Regularizer object for Conv2D. layers.
kernel_regularizer: A `tf.keras.regularizers.Regularizer` object for
Conv2D. Default to None.
bias_regularizer: A `tf.keras.regularizers.Regularizer` object for Conv2d.
Default to None. Default to None.
bias_regularizer: tf.keras.regularizers.Regularizer object for Conv2d. activation: A `str` name of the activation function.
Default to None. gating_activation: A `str` name of the activation function for final
activation: `str` name of the activation function. gating function.
gating_activation: `str` name of the activation function for final gating **kwargs: Additional keyword arguments to be passed.
function.
**kwargs: keyword arguments to be passed.
""" """
super(SqueezeExcitation, self).__init__(**kwargs) super(SqueezeExcitation, self).__init__(**kwargs)
...@@ -180,9 +184,9 @@ def get_stochastic_depth_rate(init_rate, i, n): ...@@ -180,9 +184,9 @@ def get_stochastic_depth_rate(init_rate, i, n):
"""Get drop connect rate for the ith block. """Get drop connect rate for the ith block.
Args: Args:
init_rate: `float` initial drop rate. init_rate: A `float` of initial drop rate.
i: `int` order of the current block. i: An `int` of order of the current block.
n: `int` total number of blocks. n: An `int` total number of blocks.
Returns: Returns:
Drop rate of the ith block. Drop rate of the ith block.
...@@ -198,17 +202,17 @@ def get_stochastic_depth_rate(init_rate, i, n): ...@@ -198,17 +202,17 @@ def get_stochastic_depth_rate(init_rate, i, n):
@tf.keras.utils.register_keras_serializable(package='Vision') @tf.keras.utils.register_keras_serializable(package='Vision')
class StochasticDepth(tf.keras.layers.Layer): class StochasticDepth(tf.keras.layers.Layer):
"""Stochastic depth layer.""" """Creates a stochastic depth layer."""
def __init__(self, stochastic_depth_drop_rate, **kwargs): def __init__(self, stochastic_depth_drop_rate, **kwargs):
"""Initialize stochastic depth. """Initializes a stochastic depth layer.
Args: Args:
stochastic_depth_drop_rate: `float` drop rate. stochastic_depth_drop_rate: A `float` of drop rate.
**kwargs: keyword arguments to be passed. **kwargs: Additional keyword arguments to be passed.
Returns: Returns:
A output tensor, which should have the same shape as input. A output `tf.Tensor` of which should have the same shape as input.
""" """
super(StochasticDepth, self).__init__(**kwargs) super(StochasticDepth, self).__init__(**kwargs)
self._drop_rate = stochastic_depth_drop_rate self._drop_rate = stochastic_depth_drop_rate
...@@ -236,15 +240,15 @@ class StochasticDepth(tf.keras.layers.Layer): ...@@ -236,15 +240,15 @@ class StochasticDepth(tf.keras.layers.Layer):
@tf.keras.utils.register_keras_serializable(package='Vision') @tf.keras.utils.register_keras_serializable(package='Vision')
def pyramid_feature_fusion(inputs, target_level): def pyramid_feature_fusion(inputs, target_level):
"""Fuse all feature maps in the feature pyramid at the target level. """Fuses all feature maps in the feature pyramid at the target level.
Args: Args:
inputs: a dictionary containing the feature pyramid. The size of the input inputs: A dictionary containing the feature pyramid. The size of the input
tensor needs to be fixed. tensor needs to be fixed.
target_level: `int` the target feature level for feature fusion. target_level: An `int` of the target feature level for feature fusion.
Returns: Returns:
A float Tensor of shape [batch_size, feature_height, feature_width, A `float` `tf.Tensor` of shape [batch_size, feature_height, feature_width,
feature_channel]. feature_channel].
""" """
# Convert keys to int. # Convert keys to int.
...@@ -270,3 +274,614 @@ def pyramid_feature_fusion(inputs, target_level): ...@@ -270,3 +274,614 @@ def pyramid_feature_fusion(inputs, target_level):
resampled_feats.append(feat) resampled_feats.append(feat)
return tf.math.add_n(resampled_feats) return tf.math.add_n(resampled_feats)
@tf.keras.utils.register_keras_serializable(package='Vision')
class Scale(tf.keras.layers.Layer):
  """Scales the input by a trainable scalar weight.

  This is useful for applying ReZero to layers, which improves convergence
  speed. This implements the paper:
  Thomas Bachlechner, Bodhisattwa Prasad Majumder, Huanru Henry Mao,
  Garrison W. Cottrell, Julian McAuley.
  ReZero is All You Need: Fast Convergence at Large Depth.
  (https://arxiv.org/pdf/2003.04887.pdf).
  """

  def __init__(
      self,
      initializer: tf.keras.initializers.Initializer = 'ones',
      regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
      **kwargs):
    """Initializes a scale layer.

    Args:
      initializer: A `str` of initializer for the scalar weight.
      regularizer: A `tf.keras.regularizers.Regularizer` for the scalar
        weight.
      **kwargs: Additional keyword arguments to be passed to this layer.
    """
    super(Scale, self).__init__(**kwargs)

    self._initializer = initializer
    self._regularizer = regularizer

    # A single trainable scalar multiplier applied to the whole input.
    self._scale = self.add_weight(
        name='scale',
        shape=[],
        dtype=self.dtype,
        initializer=self._initializer,
        regularizer=self._regularizer,
        trainable=True)

  def get_config(self):
    """Returns a dictionary containing the config used for initialization."""
    base_config = super(Scale, self).get_config()
    base_config.update({
        'initializer': self._initializer,
        'regularizer': self._regularizer,
    })
    return base_config

  def call(self, inputs):
    """Multiplies the inputs by the learned scalar weight."""
    return tf.cast(self._scale, inputs.dtype) * inputs
@tf.keras.utils.register_keras_serializable(package='Vision')
class TemporalSoftmaxPool(tf.keras.layers.Layer):
  """Creates a network layer corresponding to temporal softmax pooling.

  This is useful for multi-class logits (used in e.g., Charades). Modified from
  AssembleNet Charades evaluation from:

  Michael S. Ryoo, AJ Piergiovanni, Mingxing Tan, Anelia Angelova.
  AssembleNet: Searching for Multi-Stream Neural Connectivity in Video
  Architectures.
  (https://arxiv.org/pdf/1905.13209.pdf).
  """

  def call(self, inputs):
    """Calls the layer with the given inputs.

    Args:
      inputs: An input `tf.Tensor` of rank 3, 4, or 5, with frames on axis 1.

    Returns:
      An output `tf.Tensor` with the same shape as `inputs`, where each frame
      is reweighted by a softmax taken over the (scaled) inputs across axis 1.

    Raises:
      ValueError: If `inputs` does not have rank 3, 4, or 5.
    """
    # Raise explicitly instead of `assert`: asserts are stripped when Python
    # runs with optimizations enabled (`-O`), which would silently skip this
    # validation. Matches the ValueError style of SpatialAveragePool3D.
    if inputs.shape.rank not in (3, 4, 5):
      raise ValueError(
          'Input should have rank 3, 4, or 5, got {}'.format(
              inputs.shape.rank))
    frames = tf.shape(inputs)[1]
    # Dividing by sqrt(num_frames) tempers the softmax as the number of
    # frames grows.
    pre_logits = inputs / tf.sqrt(tf.cast(frames, inputs.dtype))
    activations = tf.nn.softmax(pre_logits, axis=1)
    outputs = inputs * activations
    return outputs
@tf.keras.utils.register_keras_serializable(package='Vision')
class PositionalEncoding(tf.keras.layers.Layer):
  """Creates a network layer that adds a sinusoidal positional encoding.

  Positional encoding is incremented across frames, and is added to the input.
  The positional encoding is first weighted at 0 so that the network can choose
  to ignore it. This implements:

  Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones,
  Aidan N. Gomez, Lukasz Kaiser, Illia Polosukhin.
  Attention Is All You Need.
  (https://arxiv.org/pdf/1706.03762.pdf).
  """

  def __init__(self,
               initializer: tf.keras.initializers.Initializer = 'zeros',
               cache_encoding: bool = False,
               **kwargs):
    """Initializes the positional encoding layer.

    Args:
      initializer: A `str` of initializer for weighting the positional
        encoding.
      cache_encoding: A `bool`. If True, cache the positional encoding tensor
        after calling build. Otherwise, rebuild the tensor for every call.
        Setting this to False can be useful when we want to input a variable
        number of frames, so the positional encoding tensor can change shape.
      **kwargs: Additional keyword arguments to be passed to this layer.
    """
    super(PositionalEncoding, self).__init__(**kwargs)
    self._initializer = initializer
    self._cache_encoding = cache_encoding
    self._pos_encoding = None
    # Scalar weight on the encoding; the default 'zeros' initializer means
    # the encoding starts out ignored (ReZero-style, see `Scale`).
    self._rezero = Scale(initializer=initializer, name='rezero')

  def get_config(self):
    """Returns a dictionary containing the config used for initialization."""
    base_config = super(PositionalEncoding, self).get_config()
    base_config.update({
        'initializer': self._initializer,
        'cache_encoding': self._cache_encoding,
    })
    return base_config

  def _positional_encoding(self,
                           num_positions: int,
                           hidden_size: int,
                           dtype: tf.DType = tf.float32):
    """Creates a sequence of sinusoidal positional encoding vectors.

    Args:
      num_positions: An `int` of number of positions (frames).
      hidden_size: An `int` of number of channels used for the hidden vectors.
      dtype: The dtype of the output tensor.

    Returns:
      The positional encoding tensor with shape [num_positions, hidden_size].
    """
    # `tf.range` cannot be created with `dtype=tf.bfloat16` directly, so
    # build the integer ranges first and cast afterward.
    position = tf.cast(tf.range(num_positions), dtype)[:, tf.newaxis]
    channel = tf.range(hidden_size)[tf.newaxis, :]
    # Channel pairs (2i, 2i+1) share an exponent of 2*(i)/hidden_size.
    exponent = tf.cast(2 * (channel // 2), dtype) / tf.cast(hidden_size, dtype)
    radians = position * (1. / tf.math.pow(10_000., exponent))
    # Sines of the even-indexed angles followed by cosines of the
    # odd-indexed ones, concatenated along the channel axis.
    return tf.concat(
        [tf.math.sin(radians[:, 0::2]), tf.math.cos(radians[:, 1::2])],
        axis=-1)

  def _get_pos_encoding(self, input_shape):
    """Calculates the positional encoding from the input shape."""
    num_frames, num_channels = input_shape[1], input_shape[-1]
    encoding = self._positional_encoding(
        num_frames, num_channels, dtype=self.dtype)
    # Reshape so the encoding broadcasts against rank-5 inputs.
    return tf.reshape(encoding, [1, num_frames, 1, 1, num_channels])

  def build(self, input_shape):
    """Builds the layer with the given input shape.

    Args:
      input_shape: The input shape.

    Raises:
      ValueError: If using 'channels_first' data format.
    """
    if tf.keras.backend.image_data_format() == 'channels_first':
      raise ValueError('"channels_first" mode is unsupported.')

    if self._cache_encoding:
      self._pos_encoding = self._get_pos_encoding(input_shape)

    super(PositionalEncoding, self).build(input_shape)

  def call(self, inputs):
    """Adds the (reweighted) positional encoding to the inputs."""
    if self._cache_encoding:
      encoding = self._pos_encoding
    else:
      encoding = self._get_pos_encoding(tf.shape(inputs))
    # The encoding is a fixed signal: match the input dtype and block
    # gradients from flowing into it.
    encoding = tf.stop_gradient(tf.cast(encoding, inputs.dtype))
    return inputs + self._rezero(encoding)
@tf.keras.utils.register_keras_serializable(package='Vision')
class GlobalAveragePool3D(tf.keras.layers.Layer):
  """Creates a global average pooling layer with causal mode.

  Implements causal mode, which runs a cumulative sum (with `tf.cumsum`) across
  frames in the time dimension, allowing the use of a stream buffer. Sums any
  valid input state with the current input to allow state to accumulate over
  several iterations.
  """

  def __init__(self,
               keepdims: bool = False,
               causal: bool = False,
               **kwargs):
    """Initializes a global average pool layer.

    Args:
      keepdims: A `bool`. If True, keep the averaged dimensions.
      causal: A `bool` of whether to run in causal mode with a cumulative sum
        across frames.
      **kwargs: Additional keyword arguments to be passed to this layer.
    """
    super(GlobalAveragePool3D, self).__init__(**kwargs)

    self._keepdims = keepdims
    self._causal = causal
    # NOTE(review): unused on the instance; the running frame count lives in
    # the `states` mapping threaded through `call`, keyed by
    # `self._frame_count_name`.
    self._frame_count = None

  def get_config(self):
    """Returns a dictionary containing the config used for initialization."""
    config = {
        'keepdims': self._keepdims,
        'causal': self._causal,
    }
    base_config = super(GlobalAveragePool3D, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def build(self, input_shape):
    """Builds the layer with the given input shape."""
    # Here we define strings that will uniquely reference the buffer states
    # in the TF graph. These will be used for passing in a mapping of states
    # for streaming mode. To do this, we can use a name scope.
    with tf.name_scope('buffer') as state_name:
      self._state_name = state_name
      self._frame_count_name = state_name + '_frame_count'

    super(GlobalAveragePool3D, self).build(input_shape)

  def call(self,
           inputs: tf.Tensor,
           states: Optional[States] = None,
           output_states: bool = True
           ) -> Union[tf.Tensor, Tuple[tf.Tensor, States]]:
    """Calls the layer with the given inputs.

    Args:
      inputs: An input `tf.Tensor`.
      states: A `dict` of states such that, if any of the keys match for this
        layer, will overwrite the contents of the buffer(s).
      output_states: A `bool`. If True, returns the output tensor and output
        states. Returns just the output tensor otherwise.

    Returns:
      An output `tf.Tensor` (and optionally the states if
      `output_states=True`). If `causal=True`, the output tensor will have
      shape `[batch_size, num_frames, 1, 1, channels]` if `keepdims=True`. We
      keep the frame dimension in this case to simulate a cumulative global
      average as if we are inputting one frame at a time. If `causal=False`,
      the output is equivalent to `tf.keras.layers.GlobalAveragePooling3D`
      with shape `[batch_size, 1, 1, 1, channels]` if `keepdims=True` (plus
      the optional buffer stored in `states`).

    Raises:
      ValueError: If using 'channels_first' data format.
    """
    # Copy so the caller's dict is never mutated in place.
    states = dict(states) if states is not None else {}

    if tf.keras.backend.image_data_format() == 'channels_first':
      raise ValueError('"channels_first" mode is unsupported.')

    # Shape: [batch_size, 1, 1, 1, channels]
    buffer = states.get(self._state_name, None)
    if buffer is None:
      buffer = tf.zeros_like(inputs[:, :1, :1, :1], dtype=inputs.dtype)
      states[self._state_name] = buffer

    # Keep a count of frames encountered across input iterations in
    # num_frames to be able to accurately take a cumulative average across
    # all frames when running in streaming mode
    num_frames = tf.shape(inputs)[1]
    frame_count = states.get(self._frame_count_name, 0)
    states[self._frame_count_name] = frame_count + num_frames

    if self._causal:
      # Take a mean of spatial dimensions to make computation more efficient.
      x = tf.reduce_mean(inputs, axis=[2, 3], keepdims=True)
      x = tf.cumsum(x, axis=1)
      x = x + buffer

      # The last frame will be the value of the next state
      # Shape: [batch_size, 1, 1, 1, channels]
      states[self._state_name] = x[:, -1:]

      # In causal mode, the divisor increments by 1 for every frame to
      # calculate cumulative averages instead of one global average
      mean_divisors = tf.range(num_frames) + frame_count + 1
      mean_divisors = tf.reshape(mean_divisors, [1, num_frames, 1, 1, 1])
      mean_divisors = tf.cast(mean_divisors, x.dtype)

      # Shape: [batch_size, num_frames, 1, 1, channels]
      x = x / mean_divisors
    else:
      # In non-causal mode, we (optionally) sum across frames to take a
      # cumulative average across input iterations rather than individual
      # frames. If no buffer state is passed, this essentially becomes
      # regular global average pooling.
      # Shape: [batch_size, 1, 1, 1, channels]
      x = tf.reduce_sum(inputs, axis=(1, 2, 3), keepdims=True)
      # Divide by the spatial size only here; the frame divisor is applied
      # after adding the buffer so averages stay correct across streamed
      # chunks.
      x = x / tf.cast(inputs.shape[2] * inputs.shape[3], x.dtype)
      x = x + buffer

      # Shape: [batch_size, 1, 1, 1, channels]
      states[self._state_name] = x

      x = x / tf.cast(frame_count + num_frames, x.dtype)

    if not self._keepdims:
      x = tf.squeeze(x, axis=(1, 2, 3))

    return (x, states) if output_states else x
@tf.keras.utils.register_keras_serializable(package='Vision')
class SpatialAveragePool3D(tf.keras.layers.Layer):
  """Creates an average pooling layer that pools across spatial dimensions.

  Expects rank-5 inputs and averages over axes 2 and 3 only.
  """

  def __init__(self, keepdims: bool = False, **kwargs):
    """Initializes a spatial average pool layer.

    Args:
      keepdims: A `bool`. If True, keep the averaged dimensions.
      **kwargs: Additional keyword arguments to be passed to this layer.
    """
    super(SpatialAveragePool3D, self).__init__(**kwargs)
    self._keepdims = keepdims

  def get_config(self):
    """Returns a dictionary containing the config used for initialization."""
    base_config = super(SpatialAveragePool3D, self).get_config()
    base_config.update({'keepdims': self._keepdims})
    return base_config

  def build(self, input_shape):
    """Builds the layer with the given input shape.

    Raises:
      ValueError: If using 'channels_first' data format.
    """
    if tf.keras.backend.image_data_format() == 'channels_first':
      raise ValueError('"channels_first" mode is unsupported.')
    super(SpatialAveragePool3D, self).build(input_shape)

  def call(self, inputs):
    """Averages the rank-5 inputs over the two spatial axes (2 and 3).

    Raises:
      ValueError: If `inputs` does not have rank 5.
    """
    if inputs.shape.rank != 5:
      raise ValueError(
          'Input should have rank {}, got {}'.format(5, inputs.shape.rank))
    return tf.reduce_mean(inputs, axis=(2, 3), keepdims=self._keepdims)
class CausalConvMixin:
  """Mixin class to implement CausalConv for `tf.keras.layers.Conv` layers."""

  @property
  def use_buffered_input(self) -> bool:
    """Whether the layer expects its input to be pre-padded in time."""
    return self._use_buffered_input

  @use_buffered_input.setter
  def use_buffered_input(self, variable: bool):
    self._use_buffered_input = variable

  def _compute_buffered_causal_padding(self,
                                       inputs: Optional[tf.Tensor] = None,
                                       use_buffered_input: bool = False,
                                       time_axis: int = 1) -> List[List[int]]:
    """Calculates padding for 'causal' option for conv layers.

    Args:
      inputs: An optional input `tf.Tensor` to be padded.
      use_buffered_input: A `bool`. If True, use 'valid' padding along the
        time dimension. This should be set when applying the stream buffer.
      time_axis: An `int` of the axis of the time dimension.

    Returns:
      A list of paddings for `tf.pad`.
    """
    del inputs  # Padding depends only on the layer configuration.
    if tf.keras.backend.image_data_format() == 'channels_first':
      raise ValueError('"channels_first" mode is unsupported.')

    # SAME-style padding per spatial/temporal dim, accounting for dilation,
    # bracketed by zero padding for the batch and channel axes.
    padding = [[0, 0]]
    for kernel, dilation in zip(self.kernel_size, self.dilation_rate):
      effective_kernel = kernel + (kernel - 1) * (dilation - 1)
      pad_total = effective_kernel - 1
      pad_beg = pad_total // 2
      padding.append([pad_beg, pad_total - pad_beg])
    padding.append([0, 0])

    if use_buffered_input:
      # The stream buffer already supplies the past frames, so pad nothing
      # along time ('valid').
      padding[time_axis] = [0, 0]
    else:
      # Causal: shift all of the time padding to the front.
      padding[time_axis] = [sum(padding[time_axis]), 0]
    return padding

  def _causal_validate_init(self):
    """Validates the Conv layer initial configuration."""
    # Overriding this method is meant to circumvent unnecessary errors when
    # using causal padding.
    if (self.filters is not None
        and self.filters % self.groups != 0):
      raise ValueError(
          'The number of filters must be evenly divisible by the number of '
          'groups. Received: groups={}, filters={}'.format(
              self.groups, self.filters))
    if not all(self.kernel_size):
      raise ValueError('The argument `kernel_size` cannot contain 0(s). '
                       'Received: %s' % (self.kernel_size,))

  def _buffered_spatial_output_shape(self, spatial_output_shape: List[int]):
    """Computes the spatial output shape from the input shape."""
    # When buffer padding, use 'valid' padding across time. The output shape
    # across time should be the input shape minus any padding, assuming
    # the stride across time is 1.
    if self._use_buffered_input:
      pads = self._compute_buffered_causal_padding(use_buffered_input=False)
      spatial_output_shape[0] -= sum(pads[1])
    return spatial_output_shape
@tf.keras.utils.register_keras_serializable(package='Vision')
class Conv2D(tf.keras.layers.Conv2D, CausalConvMixin):
  """Conv2D layer supporting CausalConv.

  Supports `padding='causal'` option (like in `tf.keras.layers.Conv1D`),
  which applies causal padding to the temporal dimension, and same padding in
  the spatial dimensions.
  """

  def __init__(self, *args, use_buffered_input=False, **kwargs):
    """Initializes conv2d.

    Args:
      *args: Arguments to be passed.
      use_buffered_input: A `bool`. If True, the input is expected to be
        padded beforehand. In effect, calling this layer will use 'valid'
        padding on the temporal dimension to simulate 'causal' padding.
      **kwargs: Additional keyword arguments to be passed.
    """
    super(Conv2D, self).__init__(*args, **kwargs)
    self._use_buffered_input = use_buffered_input

  def get_config(self):
    """Returns a dictionary containing the config used for initialization."""
    base_config = super(Conv2D, self).get_config()
    base_config.update({'use_buffered_input': self._use_buffered_input})
    return base_config

  def _compute_causal_padding(self, inputs):
    """Computes causal padding dimensions for the given inputs."""
    # Delegate to the mixin so buffered mode can switch to 'valid' in time.
    return self._compute_buffered_causal_padding(
        inputs, use_buffered_input=self._use_buffered_input)

  def _validate_init(self):
    """Validates the Conv layer initial configuration."""
    # Use the mixin's validation, which skips checks that misfire for
    # causal padding.
    self._causal_validate_init()

  def _spatial_output_shape(self, spatial_input_shape: List[int]):
    """Computes the spatial output shape from the input shape."""
    base_shape = super(Conv2D, self)._spatial_output_shape(spatial_input_shape)
    return self._buffered_spatial_output_shape(base_shape)
@tf.keras.utils.register_keras_serializable(package='Vision')
class DepthwiseConv2D(tf.keras.layers.DepthwiseConv2D, CausalConvMixin):
  """DepthwiseConv2D layer supporting CausalConv.

  Supports `padding='causal'` option (like in `tf.keras.layers.Conv1D`),
  which applies causal padding to the temporal dimension, and same padding in
  the spatial dimensions.
  """

  def __init__(self, *args, use_buffered_input=False, **kwargs):
    """Initializes depthwise conv2d.

    Args:
      *args: Arguments to be passed.
      use_buffered_input: A `bool`. If True, the input is expected to be
        padded beforehand. In effect, calling this layer will use 'valid'
        padding on the temporal dimension to simulate 'causal' padding.
      **kwargs: Additional keyword arguments to be passed.
    """
    super(DepthwiseConv2D, self).__init__(*args, **kwargs)
    self._use_buffered_input = use_buffered_input

    # Causal padding is unsupported by default for DepthwiseConv2D,
    # so we resort to valid padding internally. However, we handle
    # causal padding as a special case with `self._is_causal`, which is
    # defined by the super class.
    if self.padding == 'causal':
      self.padding = 'valid'

  def get_config(self):
    """Returns a dictionary containing the config used for initialization."""
    base_config = super(DepthwiseConv2D, self).get_config()
    base_config.update({'use_buffered_input': self._use_buffered_input})
    return base_config

  def call(self, inputs):
    """Calls the layer with the given inputs."""
    # Pad manually for causal mode since the base layer runs with 'valid'.
    if self._is_causal:
      inputs = tf.pad(inputs, self._compute_causal_padding(inputs))
    return super(DepthwiseConv2D, self).call(inputs)

  def _compute_causal_padding(self, inputs):
    """Computes causal padding dimensions for the given inputs."""
    return self._compute_buffered_causal_padding(
        inputs, use_buffered_input=self._use_buffered_input)

  def _validate_init(self):
    """Validates the Conv layer initial configuration."""
    # Use the mixin's validation, which skips checks that misfire for
    # causal padding.
    self._causal_validate_init()

  def _spatial_output_shape(self, spatial_input_shape: List[int]):
    """Computes the spatial output shape from the input shape."""
    base_shape = super(DepthwiseConv2D, self)._spatial_output_shape(
        spatial_input_shape)
    return self._buffered_spatial_output_shape(base_shape)
@tf.keras.utils.register_keras_serializable(package='Vision')
class Conv3D(tf.keras.layers.Conv3D, CausalConvMixin):
  """Conv3D layer supporting CausalConv.

  Supports `padding='causal'` option (like in `tf.keras.layers.Conv1D`),
  which applies causal padding to the temporal dimension, and same padding in
  the spatial dimensions.
  """

  def __init__(self, *args, use_buffered_input=False, **kwargs):
    """Initializes conv3d.

    Args:
      *args: Arguments to be passed.
      use_buffered_input: A `bool`. If True, the input is expected to be
        padded beforehand. In effect, calling this layer will use 'valid'
        padding on the temporal dimension to simulate 'causal' padding.
      **kwargs: Additional keyword arguments to be passed.
    """
    super(Conv3D, self).__init__(*args, **kwargs)
    self._use_buffered_input = use_buffered_input

  def get_config(self):
    """Returns a dictionary containing the config used for initialization."""
    base_config = super(Conv3D, self).get_config()
    base_config.update({'use_buffered_input': self._use_buffered_input})
    return base_config

  def build(self, input_shape):
    """Builds the layer with the given input shape."""
    super(Conv3D, self).build(input_shape)

    # TODO(b/177662019): tf.nn.conv3d with depthwise kernels on CPU
    # in eager mode may produce incorrect output or cause a segfault.
    # To avoid this issue, compile the op to TF graph using tf.function.
    self._convolution_op = tf.function(
        self._convolution_op, experimental_compile=True)

  def _compute_causal_padding(self, inputs):
    """Computes causal padding dimensions for the given inputs."""
    return self._compute_buffered_causal_padding(
        inputs, use_buffered_input=self._use_buffered_input)

  def _validate_init(self):
    """Validates the Conv layer initial configuration."""
    # Use the mixin's validation, which skips checks that misfire for
    # causal padding.
    self._causal_validate_init()

  def _spatial_output_shape(self, spatial_input_shape: List[int]):
    """Computes the spatial output shape from the input shape."""
    base_shape = super(Conv3D, self)._spatial_output_shape(spatial_input_shape)
    return self._buffered_spatial_output_shape(base_shape)
# Lint as: python3
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for nn_layers."""
# Import libraries
from absl.testing import parameterized
import tensorflow as tf
from official.vision.beta.modeling.layers import nn_layers
class NNLayersTest(parameterized.TestCase, tf.test.TestCase):
  """Unit tests for the custom layers defined in `nn_layers`."""

  def test_scale(self):
    """Scale layer multiplies the input by its constant-initialized scale."""
    scale = nn_layers.Scale(initializer=tf.keras.initializers.constant(10.))
    output = scale(3.)
    self.assertAllEqual(output, 30.)

  def test_temporal_softmax_pool(self):
    """TemporalSoftmaxPool output matches precomputed expected values."""
    inputs = tf.range(4, dtype=tf.float32) + 1.
    inputs = tf.reshape(inputs, [1, 4, 1, 1, 1])
    layer = nn_layers.TemporalSoftmaxPool()
    output = layer(inputs)
    self.assertAllClose(
        output,
        [[[[[0.10153633]]],
          [[[0.33481020]]],
          [[[0.82801306]]],
          [[[1.82021690]]]]])

  def test_positional_encoding(self):
    """Encoding matches expected values; cached and uncached variants agree."""
    pos_encoding = nn_layers.PositionalEncoding(
        initializer='ones', cache_encoding=False)
    pos_encoding_cached = nn_layers.PositionalEncoding(
        initializer='ones', cache_encoding=True)
    inputs = tf.ones([1, 4, 1, 1, 3])
    outputs = pos_encoding(inputs)
    outputs_cached = pos_encoding_cached(inputs)
    expected = tf.constant(
        [[[[[1.0000000, 1.0000000, 2.0000000]]],
          [[[1.8414710, 1.0021545, 1.5403023]]],
          [[[1.9092975, 1.0043088, 0.5838531]]],
          [[[1.1411200, 1.0064633, 0.0100075]]]]])
    self.assertEqual(outputs.shape, expected.shape)
    self.assertAllClose(outputs, expected)
    self.assertEqual(outputs.shape, outputs_cached.shape)
    self.assertAllClose(outputs, outputs_cached)
    # Exercise a different temporal length with the same (uncached) layer.
    inputs = tf.ones([1, 5, 1, 1, 3])
    _ = pos_encoding(inputs)

  def test_positional_encoding_bfloat16(self):
    """Positional encoding on bfloat16 inputs matches expected values."""
    pos_encoding = nn_layers.PositionalEncoding(initializer='ones')
    inputs = tf.ones([1, 4, 1, 1, 3], dtype=tf.bfloat16)
    outputs = pos_encoding(inputs)
    expected = tf.constant(
        [[[[[1.0000000, 1.0000000, 2.0000000]]],
          [[[1.8414710, 1.0021545, 1.5403023]]],
          [[[1.9092975, 1.0043088, 0.5838531]]],
          [[[1.1411200, 1.0064633, 0.0100075]]]]])
    self.assertEqual(outputs.shape, expected.shape)
    self.assertAllClose(outputs, expected)

  def test_global_average_pool_basic(self):
    """Pooling an all-ones input yields all ones (keepdims keeps rank 5)."""
    pool = nn_layers.GlobalAveragePool3D(keepdims=True)
    inputs = tf.ones([1, 2, 3, 4, 1])
    outputs = pool(inputs, output_states=False)
    expected = tf.ones([1, 1, 1, 1, 1])
    self.assertEqual(outputs.shape, expected.shape)
    self.assertAllEqual(outputs, expected)

  def test_global_average_pool_keras(self):
    """GlobalAveragePool3D agrees with Keras GlobalAveragePooling3D."""
    pool = nn_layers.GlobalAveragePool3D(keepdims=False)
    keras_pool = tf.keras.layers.GlobalAveragePooling3D()
    inputs = 10 * tf.random.normal([1, 2, 3, 4, 1])
    outputs = pool(inputs, output_states=False)
    keras_output = keras_pool(inputs)
    self.assertAllEqual(outputs.shape, keras_output.shape)
    self.assertAllClose(outputs, keras_output)

  def test_stream_global_average_pool(self):
    """Frame-by-frame streaming pooling matches pooling the whole clip."""
    gap = nn_layers.GlobalAveragePool3D(keepdims=True, causal=False)
    inputs = tf.range(4, dtype=tf.float32) + 1.
    inputs = tf.reshape(inputs, [1, 4, 1, 1, 1])
    inputs = tf.tile(inputs, [1, 1, 2, 2, 3])
    expected, _ = gap(inputs)
    for num_splits in [1, 2, 4]:
      frames = tf.split(inputs, num_splits, axis=1)
      states = {}
      predicted = None
      # Feed frame chunks sequentially, threading the states dict through.
      for frame in frames:
        predicted, states = gap(frame, states=states)
      self.assertEqual(predicted.shape, expected.shape)
      self.assertAllClose(predicted, expected)
      # Mean of temporal values [1, 2, 3, 4] is 2.5.
      self.assertAllClose(
          predicted,
          [[[[[2.5, 2.5, 2.5]]]]])

  def test_causal_stream_global_average_pool(self):
    """Causal streaming pooling matches the full-clip causal result."""
    gap = nn_layers.GlobalAveragePool3D(keepdims=True, causal=True)
    inputs = tf.range(4, dtype=tf.float32) + 1.
    inputs = tf.reshape(inputs, [1, 4, 1, 1, 1])
    inputs = tf.tile(inputs, [1, 1, 2, 2, 3])
    expected, _ = gap(inputs)
    for num_splits in [1, 2, 4]:
      frames = tf.split(inputs, num_splits, axis=1)
      states = {}
      predicted = []
      for frame in frames:
        x, states = gap(frame, states=states)
        predicted.append(x)
      predicted = tf.concat(predicted, axis=1)
      self.assertEqual(predicted.shape, expected.shape)
      self.assertAllClose(predicted, expected)
      # Running means of [1, 2, 3, 4] along the temporal axis.
      self.assertAllClose(
          predicted,
          [[[[[1.0, 1.0, 1.0]]],
            [[[1.5, 1.5, 1.5]]],
            [[[2.0, 2.0, 2.0]]],
            [[[2.5, 2.5, 2.5]]]]])

  def test_spatial_average_pool(self):
    """Spatial pooling averages over H and W, keeping the temporal axis."""
    pool = nn_layers.SpatialAveragePool3D(keepdims=True)
    inputs = tf.range(64, dtype=tf.float32) + 1.
    inputs = tf.reshape(inputs, [1, 4, 4, 4, 1])
    output = pool(inputs)
    self.assertEqual(output.shape, [1, 4, 1, 1, 1])
    self.assertAllClose(
        output,
        [[[[[8.50]]],
          [[[24.5]]],
          [[[40.5]]],
          [[[56.5]]]]])

  def test_conv2d_causal(self):
    """Buffered (pre-padded) and unbuffered causal Conv2D outputs match."""
    conv2d = nn_layers.Conv2D(
        filters=3,
        kernel_size=(3, 3),
        strides=(1, 2),
        padding='causal',
        use_buffered_input=True,
        kernel_initializer='ones',
        use_bias=False,
    )
    inputs = tf.ones([1, 4, 2, 3])
    # Manually apply the causal (left-only) padding on the temporal axis.
    paddings = [[0, 0], [2, 0], [0, 0], [0, 0]]
    padded_inputs = tf.pad(inputs, paddings)
    predicted = conv2d(padded_inputs)
    expected = tf.constant(
        [[[[6.0, 6.0, 6.0]],
          [[12., 12., 12.]],
          [[18., 18., 18.]],
          [[18., 18., 18.]]]])
    self.assertEqual(predicted.shape, expected.shape)
    self.assertAllClose(predicted, expected)
    # Unbuffered mode takes the raw (unpadded) input and must match.
    conv2d.use_buffered_input = False
    predicted = conv2d(inputs)
    self.assertFalse(conv2d.use_buffered_input)
    self.assertEqual(predicted.shape, expected.shape)
    self.assertAllClose(predicted, expected)

  def test_depthwise_conv2d_causal(self):
    """Buffered and unbuffered causal DepthwiseConv2D outputs match."""
    conv2d = nn_layers.DepthwiseConv2D(
        kernel_size=(3, 3),
        strides=(1, 1),
        padding='causal',
        use_buffered_input=True,
        depthwise_initializer='ones',
        use_bias=False,
    )
    inputs = tf.ones([1, 2, 2, 3])
    paddings = [[0, 0], [2, 0], [0, 0], [0, 0]]
    padded_inputs = tf.pad(inputs, paddings)
    predicted = conv2d(padded_inputs)
    expected = tf.constant(
        [[[[2., 2., 2.],
           [2., 2., 2.]],
          [[4., 4., 4.],
           [4., 4., 4.]]]])
    self.assertEqual(predicted.shape, expected.shape)
    self.assertAllClose(predicted, expected)
    conv2d.use_buffered_input = False
    predicted = conv2d(inputs)
    self.assertEqual(predicted.shape, expected.shape)
    self.assertAllClose(predicted, expected)

  def test_conv3d_causal(self):
    """Buffered and unbuffered causal Conv3D outputs match."""
    conv3d = nn_layers.Conv3D(
        filters=3,
        kernel_size=(3, 3, 3),
        strides=(1, 2, 2),
        padding='causal',
        use_buffered_input=True,
        kernel_initializer='ones',
        use_bias=False,
    )
    inputs = tf.ones([1, 2, 4, 4, 3])
    paddings = [[0, 0], [2, 0], [0, 0], [0, 0], [0, 0]]
    padded_inputs = tf.pad(inputs, paddings)
    predicted = conv3d(padded_inputs)
    expected = tf.constant(
        [[[[[12., 12., 12.],
            [18., 18., 18.]],
           [[18., 18., 18.],
            [27., 27., 27.]]],
          [[[24., 24., 24.],
            [36., 36., 36.]],
           [[36., 36., 36.],
            [54., 54., 54.]]]]])
    self.assertEqual(predicted.shape, expected.shape)
    self.assertAllClose(predicted, expected)
    conv3d.use_buffered_input = False
    predicted = conv3d(inputs)
    self.assertEqual(predicted.shape, expected.shape)
    self.assertAllClose(predicted, expected)

  def test_depthwise_conv3d_causal(self):
    """Grouped (depthwise-style) causal Conv3D, buffered vs. unbuffered."""
    conv3d = nn_layers.Conv3D(
        filters=3,
        kernel_size=(3, 3, 3),
        strides=(1, 2, 2),
        padding='causal',
        use_buffered_input=True,
        kernel_initializer='ones',
        use_bias=False,
        groups=3,
    )
    inputs = tf.ones([1, 2, 4, 4, 3])
    paddings = [[0, 0], [2, 0], [0, 0], [0, 0], [0, 0]]
    padded_inputs = tf.pad(inputs, paddings)
    predicted = conv3d(padded_inputs)
    expected = tf.constant(
        [[[[[4.0, 4.0, 4.0],
            [6.0, 6.0, 6.0]],
           [[6.0, 6.0, 6.0],
            [9.0, 9.0, 9.0]]],
          [[[8.0, 8.0, 8.0],
            [12., 12., 12.]],
           [[12., 12., 12.],
            [18., 18., 18.]]]]])
    self.assertEqual(predicted.shape, expected.shape)
    self.assertAllClose(predicted, expected)
    conv3d.use_buffered_input = False
    predicted = conv3d(inputs)
    self.assertEqual(predicted.shape, expected.shape)
    self.assertAllClose(predicted, expected)
# Run the test suite when executed directly as a script.
if __name__ == '__main__':
  tf.test.main()
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""ROI align.""" """Contains definitions of ROI aligner."""
import tensorflow as tf import tensorflow as tf
...@@ -30,9 +30,9 @@ class MultilevelROIAligner(tf.keras.layers.Layer): ...@@ -30,9 +30,9 @@ class MultilevelROIAligner(tf.keras.layers.Layer):
"""Initializes a ROI aligner. """Initializes a ROI aligner.
Args: Args:
crop_size: int, the output size of the cropped features. crop_size: An `int` of the output size of the cropped features.
sample_offset: float in [0, 1], the subpixel sample offset. sample_offset: A `float` in [0, 1] of the subpixel sample offset.
**kwargs: other key word arguments passed to Layer. **kwargs: Additional keyword arguments passed to Layer.
""" """
self._config_dict = { self._config_dict = {
'crop_size': crop_size, 'crop_size': crop_size,
...@@ -47,13 +47,13 @@ class MultilevelROIAligner(tf.keras.layers.Layer): ...@@ -47,13 +47,13 @@ class MultilevelROIAligner(tf.keras.layers.Layer):
features: A dictionary with key as pyramid level and value as features. features: A dictionary with key as pyramid level and value as features.
The features are in shape of The features are in shape of
[batch_size, height_l, width_l, num_filters]. [batch_size, height_l, width_l, num_filters].
boxes: A 3-D Tensor of shape [batch_size, num_boxes, 4]. Each row boxes: A 3-D `tf.Tensor` of shape [batch_size, num_boxes, 4]. Each row
represents a box with [y1, x1, y2, x2] in un-normalized coordinates. represents a box with [y1, x1, y2, x2] in un-normalized coordinates.
from grid point. from grid point.
training: bool, whether it is in training mode. training: A `bool` of whether it is in training mode.
Returns: Returns:
roi_features: A 5-D tensor representing feature crop of shape A 5-D `tf.Tensor` representing feature crop of shape
[batch_size, num_boxes, crop_size, crop_size, num_filters]. [batch_size, num_boxes, crop_size, crop_size, num_filters].
""" """
roi_features = spatial_transform_ops.multilevel_crop_and_resize( roi_features = spatial_transform_ops.multilevel_crop_and_resize(
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""ROI generator.""" """Contains definitions of ROI generator."""
# Import libraries # Import libraries
import tensorflow as tf import tensorflow as tf
...@@ -48,46 +48,48 @@ def _multilevel_propose_rois(raw_boxes, ...@@ -48,46 +48,48 @@ def _multilevel_propose_rois(raw_boxes,
3. Apply an overall top k to generate the final selected RoIs. 3. Apply an overall top k to generate the final selected RoIs.
Args: Args:
raw_boxes: a dict with keys representing FPN levels and values representing raw_boxes: A `dict` with keys representing FPN levels and values
box tensors of shape [batch_size, feature_h, feature_w, num_anchors * 4]. representing box tensors of shape
raw_scores: a dict with keys representing FPN levels and values representing [batch_size, feature_h, feature_w, num_anchors * 4].
logit tensors of shape [batch_size, feature_h, feature_w, num_anchors]. raw_scores: A `dict` with keys representing FPN levels and values
anchor_boxes: a dict with keys representing FPN levels and values representing logit tensors of shape
[batch_size, feature_h, feature_w, num_anchors].
anchor_boxes: A `dict` with keys representing FPN levels and values
representing anchor box tensors of shape representing anchor box tensors of shape
[batch_size, feature_h * feature_w * num_anchors, 4]. [batch_size, feature_h * feature_w * num_anchors, 4].
image_shape: a tensor of shape [batch_size, 2] where the last dimension are image_shape: A `tf.Tensor` of shape [batch_size, 2] where the last dimension
[height, width] of the scaled image. are [height, width] of the scaled image.
pre_nms_top_k: an integer of top scoring RPN proposals *per level* to pre_nms_top_k: An `int` of top scoring RPN proposals *per level* to keep
keep before applying NMS. Default: 2000. before applying NMS. Default: 2000.
pre_nms_score_threshold: a float between 0 and 1 representing the minimal pre_nms_score_threshold: A `float` between 0 and 1 representing the minimal
box score to keep before applying NMS. This is often used as a box score to keep before applying NMS. This is often used as a
pre-filtering step for better performance. Default: 0, no filtering is pre-filtering step for better performance. Default: 0, no filtering is
applied. applied.
pre_nms_min_size_threshold: a float representing the minimal box size in pre_nms_min_size_threshold: A `float` representing the minimal box size in
each side (w.r.t. the scaled image) to keep before applying NMS. This is each side (w.r.t. the scaled image) to keep before applying NMS. This is
often used as a pre-filtering step for better performance. Default: 0, no often used as a pre-filtering step for better performance. Default: 0, no
filtering is applied. filtering is applied.
nms_iou_threshold: a float between 0 and 1 representing the IoU threshold nms_iou_threshold: A `float` between 0 and 1 representing the IoU threshold
used for NMS. If 0.0, no NMS is applied. Default: 0.7. used for NMS. If 0.0, no NMS is applied. Default: 0.7.
num_proposals: an integer of top scoring RPN proposals *in total* to num_proposals: An `int` of top scoring RPN proposals *in total* to keep
keep after applying NMS. Default: 1000. after applying NMS. Default: 1000.
use_batched_nms: a boolean indicating whether NMS is applied in batch using use_batched_nms: A `bool` indicating whether NMS is applied in batch using
`tf.image.combined_non_max_suppression`. Currently only available in `tf.image.combined_non_max_suppression`. Currently only available in
CPU/GPU. Default: False. CPU/GPU. Default is False.
decode_boxes: a boolean indicating whether `raw_boxes` needs to be decoded decode_boxes: A `bool` indicating whether `raw_boxes` needs to be decoded
using `anchor_boxes`. If False, use `raw_boxes` directly and ignore using `anchor_boxes`. If False, use `raw_boxes` directly and ignore
`anchor_boxes`. Default: True. `anchor_boxes`. Default is True.
clip_boxes: a boolean indicating whether boxes are first clipped to the clip_boxes: A `bool` indicating whether boxes are first clipped to the
scaled image size before applying NMS. If False, no clipping is applied scaled image size before applying NMS. If False, no clipping is applied
and `image_shape` is ignored. Default: True. and `image_shape` is ignored. Default is True.
apply_sigmoid_to_score: a boolean indicating whether apply sigmoid to apply_sigmoid_to_score: A `bool` indicating whether apply sigmoid to
`raw_scores` before applying NMS. Default: True. `raw_scores` before applying NMS. Default is True.
Returns: Returns:
selected_rois: a tensor of shape [batch_size, num_proposals, 4], selected_rois: A `tf.Tensor` of shape [batch_size, num_proposals, 4],
representing the box coordinates of the selected proposals w.r.t. the representing the box coordinates of the selected proposals w.r.t. the
scaled image. scaled image.
selected_roi_scores: a tensor of shape [batch_size, num_proposals, 1], selected_roi_scores: A `tf.Tensor` of shape [batch_size, num_proposals, 1],
representing the scores of the selected proposals. representing the scores of the selected proposals.
""" """
with tf.name_scope('multilevel_propose_rois'): with tf.name_scope('multilevel_propose_rois'):
...@@ -196,30 +198,31 @@ class MultilevelROIGenerator(tf.keras.layers.Layer): ...@@ -196,30 +198,31 @@ class MultilevelROIGenerator(tf.keras.layers.Layer):
The ROI generator transforms the raw predictions from RPN to ROIs. The ROI generator transforms the raw predictions from RPN to ROIs.
Args: Args:
pre_nms_top_k: int, the number of top scores proposals to be kept before pre_nms_top_k: An `int` of the number of top scores proposals to be kept
applying NMS. before applying NMS.
pre_nms_score_threshold: float, the score threshold to apply before pre_nms_score_threshold: A `float` of the score threshold to apply before
applying NMS. Proposals whose scores are below this threshold are applying NMS. Proposals whose scores are below this threshold are
thrown away. thrown away.
pre_nms_min_size_threshold: float, the threshold of each side of the box pre_nms_min_size_threshold: A `float` of the threshold of each side of the
(w.r.t. the scaled image). Proposals whose sides are below this box (w.r.t. the scaled image). Proposals whose sides are below this
threshold are thrown away. threshold are thrown away.
nms_iou_threshold: float in [0, 1], the NMS IoU threshold. nms_iou_threshold: A `float` in [0, 1], the NMS IoU threshold.
num_proposals: int, the final number of proposals to generate. num_proposals: An `int` of the final number of proposals to generate.
test_pre_nms_top_k: int, the number of top scores proposals to be kept test_pre_nms_top_k: An `int` of the number of top scores proposals to be
before applying NMS in testing. kept before applying NMS in testing.
test_pre_nms_score_threshold: float, the score threshold to apply before test_pre_nms_score_threshold: A `float` of the score threshold to apply
applying NMS in testing. Proposals whose scores are below this threshold before applying NMS in testing. Proposals whose scores are below this
are thrown away. threshold are thrown away.
test_pre_nms_min_size_threshold: float, the threshold of each side of the test_pre_nms_min_size_threshold: A `float` of the threshold of each side
box (w.r.t. the scaled image) in testing. Proposals whose sides are of the box (w.r.t. the scaled image) in testing. Proposals whose sides
below this threshold are thrown away. are below this threshold are thrown away.
test_nms_iou_threshold: float in [0, 1], the NMS IoU threshold in testing. test_nms_iou_threshold: A `float` in [0, 1] of the NMS IoU threshold in
test_num_proposals: int, the final number of proposals to generate in
testing. testing.
use_batched_nms: bool, whether or not use test_num_proposals: An `int` of the final number of proposals to generate
in testing.
use_batched_nms: A `bool` of whether or not use
`tf.image.combined_non_max_suppression`. `tf.image.combined_non_max_suppression`.
**kwargs: other key word arguments passed to Layer. **kwargs: Additional keyword arguments passed to Layer.
""" """
self._config_dict = { self._config_dict = {
'pre_nms_top_k': pre_nms_top_k, 'pre_nms_top_k': pre_nms_top_k,
...@@ -257,23 +260,24 @@ class MultilevelROIGenerator(tf.keras.layers.Layer): ...@@ -257,23 +260,24 @@ class MultilevelROIGenerator(tf.keras.layers.Layer):
3. Apply an overall top k to generate the final selected RoIs. 3. Apply an overall top k to generate the final selected RoIs.
Args: Args:
raw_boxes: a dict with keys representing FPN levels and values raw_boxes: A `dict` with keys representing FPN levels and values
representing box tenors of shape representing box tenors of shape
[batch, feature_h, feature_w, num_anchors * 4]. [batch, feature_h, feature_w, num_anchors * 4].
raw_scores: a dict with keys representing FPN levels and values raw_scores: A `dict` with keys representing FPN levels and values
representing logit tensors of shape representing logit tensors of shape
[batch, feature_h, feature_w, num_anchors]. [batch, feature_h, feature_w, num_anchors].
anchor_boxes: a dict with keys representing FPN levels and values anchor_boxes: A `dict` with keys representing FPN levels and values
representing anchor box tensors of shape representing anchor box tensors of shape
[batch, feature_h * feature_w * num_anchors, 4]. [batch, feature_h * feature_w * num_anchors, 4].
image_shape: a tensor of shape [batch, 2] where the last dimension are image_shape: A `tf.Tensor` of shape [batch, 2] where the last dimension
[height, width] of the scaled image. are [height, width] of the scaled image.
training: a bool indicat whether it is in training mode. training: A `bool` that indicates whether it is in training mode.
Returns: Returns:
roi_boxes: [batch, num_proposals, 4], the proposed ROIs in the scaled roi_boxes: A `tf.Tensor` of shape [batch, num_proposals, 4], the proposed
image coordinate. ROIs in the scaled image coordinate.
roi_scores: [batch, num_proposals], scores of the proposed ROIs. roi_scores: A `tf.Tensor` of shape [batch, num_proposals], scores of the
proposed ROIs.
""" """
roi_boxes, roi_scores = _multilevel_propose_rois( roi_boxes, roi_scores = _multilevel_propose_rois(
raw_boxes, raw_boxes,
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""ROI sampler.""" """Contains definitions of ROI sampler."""
# Import libraries # Import libraries
import tensorflow as tf import tensorflow as tf
...@@ -23,7 +23,7 @@ from official.vision.beta.modeling.layers import box_sampler ...@@ -23,7 +23,7 @@ from official.vision.beta.modeling.layers import box_sampler
@tf.keras.utils.register_keras_serializable(package='Vision') @tf.keras.utils.register_keras_serializable(package='Vision')
class ROISampler(tf.keras.layers.Layer): class ROISampler(tf.keras.layers.Layer):
"""Sample ROIs and assign targets to the sampled ROIs.""" """Samples ROIs and assigns targets to the sampled ROIs."""
def __init__(self, def __init__(self,
mix_gt_boxes=True, mix_gt_boxes=True,
...@@ -36,20 +36,20 @@ class ROISampler(tf.keras.layers.Layer): ...@@ -36,20 +36,20 @@ class ROISampler(tf.keras.layers.Layer):
"""Initializes a ROI sampler. """Initializes a ROI sampler.
Args: Args:
mix_gt_boxes: bool, whether to mix the groundtruth boxes with proposed mix_gt_boxes: A `bool` of whether to mix the groundtruth boxes with
ROIs. proposed ROIs.
num_sampled_rois: int, the number of sampled ROIs per image. num_sampled_rois: An `int` of the number of sampled ROIs per image.
foreground_fraction: float in [0, 1], what percentage of proposed ROIs foreground_fraction: A `float` in [0, 1], what percentage of proposed ROIs
should be sampled from the foreground boxes. should be sampled from the foreground boxes.
foreground_iou_threshold: float, represent the IoU threshold for a box to foreground_iou_threshold: A `float` that represents the IoU threshold for
be considered as positive (if >= `foreground_iou_threshold`). a box to be considered as positive (if >= `foreground_iou_threshold`).
background_iou_high_threshold: float, represent the IoU threshold for a background_iou_high_threshold: A `float` that represents the IoU threshold
box to be considered as negative (if overlap in for a box to be considered as negative (if overlap in
[`background_iou_low_threshold`, `background_iou_high_threshold`]). [`background_iou_low_threshold`, `background_iou_high_threshold`]).
background_iou_low_threshold: float, represent the IoU threshold for a box background_iou_low_threshold: A `float` that represents the IoU threshold
to be considered as negative (if overlap in for a box to be considered as negative (if overlap in
[`background_iou_low_threshold`, `background_iou_high_threshold`]) [`background_iou_low_threshold`, `background_iou_high_threshold`])
**kwargs: other key word arguments passed to Layer. **kwargs: Additional keyword arguments passed to Layer.
""" """
self._config_dict = { self._config_dict = {
'mix_gt_boxes': mix_gt_boxes, 'mix_gt_boxes': mix_gt_boxes,
...@@ -85,29 +85,30 @@ class ROISampler(tf.keras.layers.Layer): ...@@ -85,29 +85,30 @@ class ROISampler(tf.keras.layers.Layer):
returns box_targets, class_targets, and RoIs. returns box_targets, class_targets, and RoIs.
Args: Args:
boxes: a tensor of shape of [batch_size, N, 4]. N is the number of boxes: A `tf.Tensor` of shape of [batch_size, N, 4]. N is the number of
proposals before groundtruth assignment. The last dimension is the proposals before groundtruth assignment. The last dimension is the
box coordinates w.r.t. the scaled images in [ymin, xmin, ymax, xmax] box coordinates w.r.t. the scaled images in [ymin, xmin, ymax, xmax]
format. format.
gt_boxes: a tensor of shape of [batch_size, MAX_NUM_INSTANCES, 4]. gt_boxes: A `tf.Tensor` of shape of [batch_size, MAX_NUM_INSTANCES, 4].
The coordinates of gt_boxes are in the pixel coordinates of the scaled The coordinates of gt_boxes are in the pixel coordinates of the scaled
image. This tensor might have padding of values -1 indicating the image. This tensor might have padding of values -1 indicating the
invalid box coordinates. invalid box coordinates.
gt_classes: a tensor with a shape of [batch_size, MAX_NUM_INSTANCES]. This gt_classes: A `tf.Tensor` with a shape of [batch_size, MAX_NUM_INSTANCES].
tensor might have paddings with values of -1 indicating the invalid This tensor might have paddings with values of -1 indicating the invalid
classes. classes.
Returns: Returns:
sampled_rois: a tensor of shape of [batch_size, K, 4], representing the sampled_rois: A `tf.Tensor` of shape of [batch_size, K, 4], representing
coordinates of the sampled RoIs, where K is the number of the sampled the coordinates of the sampled RoIs, where K is the number of the
RoIs, i.e. K = num_samples_per_image. sampled RoIs, i.e. K = num_samples_per_image.
sampled_gt_boxes: a tensor of shape of [batch_size, K, 4], storing the sampled_gt_boxes: A `tf.Tensor` of shape of [batch_size, K, 4], storing
box coordinates of the matched groundtruth boxes of the sampled RoIs. the box coordinates of the matched groundtruth boxes of the sampled
sampled_gt_classes: a tensor of shape of [batch_size, K], storing the RoIs.
sampled_gt_classes: A `tf.Tensor` of shape of [batch_size, K], storing the
classes of the matched groundtruth boxes of the sampled RoIs. classes of the matched groundtruth boxes of the sampled RoIs.
sampled_gt_indices: a tensor of shape of [batch_size, K], storing the sampled_gt_indices: A `tf.Tensor` of shape of [batch_size, K], storing the
indices of the sampled groundtruth boxes in the original `gt_boxes` indices of the sampled groundtruth boxes in the original `gt_boxes`
tensor, i.e. tensor, i.e.,
gt_boxes[sampled_gt_indices[:, i]] = sampled_gt_boxes[:, i]. gt_boxes[sampled_gt_indices[:, i]] = sampled_gt_boxes[:, i].
""" """
if self._config_dict['mix_gt_boxes']: if self._config_dict['mix_gt_boxes']:
......
...@@ -20,7 +20,6 @@ task: ...@@ -20,7 +20,6 @@ task:
tfds_name: 'imagenet2012' tfds_name: 'imagenet2012'
tfds_split: 'train' tfds_split: 'train'
tfds_data_dir: '~/tensorflow_datasets/' tfds_data_dir: '~/tensorflow_datasets/'
tfds_download: true
is_training: true is_training: true
global_batch_size: 16 # default = 128 global_batch_size: 16 # default = 128
dtype: 'float16' dtype: 'float16'
...@@ -29,7 +28,6 @@ task: ...@@ -29,7 +28,6 @@ task:
tfds_name: 'imagenet2012' tfds_name: 'imagenet2012'
tfds_split: 'validation' tfds_split: 'validation'
tfds_data_dir: '~/tensorflow_datasets/' tfds_data_dir: '~/tensorflow_datasets/'
tfds_download: true
is_training: true is_training: true
global_batch_size: 16 # default = 128 global_batch_size: 16 # default = 128
dtype: 'float16' dtype: 'float16'
......
...@@ -20,7 +20,6 @@ task: ...@@ -20,7 +20,6 @@ task:
tfds_name: 'imagenet2012' tfds_name: 'imagenet2012'
tfds_split: 'train' tfds_split: 'train'
tfds_data_dir: '~/tensorflow_datasets/' tfds_data_dir: '~/tensorflow_datasets/'
tfds_download: true
is_training: true is_training: true
global_batch_size: 16 # default = 128 global_batch_size: 16 # default = 128
dtype: 'float16' dtype: 'float16'
...@@ -29,7 +28,6 @@ task: ...@@ -29,7 +28,6 @@ task:
tfds_name: 'imagenet2012' tfds_name: 'imagenet2012'
tfds_split: 'validation' tfds_split: 'validation'
tfds_data_dir: '~/tensorflow_datasets/' tfds_data_dir: '~/tensorflow_datasets/'
tfds_download: true
is_training: true is_training: true
global_batch_size: 16 # default = 128 global_batch_size: 16 # default = 128
dtype: 'float16' dtype: 'float16'
......
...@@ -52,7 +52,6 @@ class DataConfig(cfg.DataConfig): ...@@ -52,7 +52,6 @@ class DataConfig(cfg.DataConfig):
decoder = None decoder = None
parser: Parser = Parser() parser: Parser = Parser()
shuffle_buffer_size: int = 10 shuffle_buffer_size: int = 10
tfds_download: bool = False
class YoloDetectionInputTest(tf.test.TestCase, parameterized.TestCase): class YoloDetectionInputTest(tf.test.TestCase, parameterized.TestCase):
......
...@@ -51,7 +51,8 @@ def main(_): ...@@ -51,7 +51,8 @@ def main(_):
# dtype is float16 # dtype is float16
if params.runtime.mixed_precision_dtype: if params.runtime.mixed_precision_dtype:
performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype, performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype,
params.runtime.loss_scale) params.runtime.loss_scale,
use_experimental_api=True)
distribution_strategy = distribute_utils.get_distribution_strategy( distribution_strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=params.runtime.distribution_strategy, distribution_strategy=params.runtime.distribution_strategy,
all_reduce_alg=params.runtime.all_reduce_alg, all_reduce_alg=params.runtime.all_reduce_alg,
......
...@@ -31,32 +31,30 @@ STDDEV_RGB = (0.229 * 255, 0.224 * 255, 0.225 * 255) ...@@ -31,32 +31,30 @@ STDDEV_RGB = (0.229 * 255, 0.224 * 255, 0.225 * 255)
class DetectionModule(export_base.ExportModule): class DetectionModule(export_base.ExportModule):
"""Detection Module.""" """Detection Module."""
def build_model(self): def _build_model(self):
if self._batch_size is None: if self._batch_size is None:
ValueError("batch_size can't be None for detection models") ValueError("batch_size can't be None for detection models")
if not self._params.task.model.detection_generator.use_batched_nms: if not self.params.task.model.detection_generator.use_batched_nms:
ValueError('Only batched_nms is supported.') ValueError('Only batched_nms is supported.')
input_specs = tf.keras.layers.InputSpec(shape=[self._batch_size] + input_specs = tf.keras.layers.InputSpec(shape=[self._batch_size] +
self._input_image_size + [3]) self._input_image_size + [3])
if isinstance(self._params.task.model, configs.maskrcnn.MaskRCNN): if isinstance(self.params.task.model, configs.maskrcnn.MaskRCNN):
self._model = factory.build_maskrcnn( model = factory.build_maskrcnn(
input_specs=input_specs, input_specs=input_specs, model_config=self.params.task.model)
model_config=self._params.task.model) elif isinstance(self.params.task.model, configs.retinanet.RetinaNet):
elif isinstance(self._params.task.model, configs.retinanet.RetinaNet): model = factory.build_retinanet(
self._model = factory.build_retinanet( input_specs=input_specs, model_config=self.params.task.model)
input_specs=input_specs,
model_config=self._params.task.model)
else: else:
raise ValueError('Detection module not implemented for {} model.'.format( raise ValueError('Detection module not implemented for {} model.'.format(
type(self._params.task.model))) type(self.params.task.model)))
return self._model return model
def _build_inputs(self, image): def _build_inputs(self, image):
"""Builds detection model inputs for serving.""" """Builds detection model inputs for serving."""
model_params = self._params.task.model model_params = self.params.task.model
# Normalizes image with mean and std pixel values. # Normalizes image with mean and std pixel values.
image = preprocess_ops.normalize_image(image, image = preprocess_ops.normalize_image(image,
offset=MEAN_RGB, offset=MEAN_RGB,
...@@ -81,7 +79,7 @@ class DetectionModule(export_base.ExportModule): ...@@ -81,7 +79,7 @@ class DetectionModule(export_base.ExportModule):
return image, anchor_boxes, image_info return image, anchor_boxes, image_info
def _run_inference_on_image_tensors(self, images: tf.Tensor): def serve(self, images: tf.Tensor):
"""Cast image to float and run inference. """Cast image to float and run inference.
Args: Args:
...@@ -89,7 +87,7 @@ class DetectionModule(export_base.ExportModule): ...@@ -89,7 +87,7 @@ class DetectionModule(export_base.ExportModule):
Returns: Returns:
Tensor holding detection output logits. Tensor holding detection output logits.
""" """
model_params = self._params.task.model model_params = self.params.task.model
with tf.device('cpu:0'): with tf.device('cpu:0'):
images = tf.cast(images, dtype=tf.float32) images = tf.cast(images, dtype=tf.float32)
...@@ -122,7 +120,7 @@ class DetectionModule(export_base.ExportModule): ...@@ -122,7 +120,7 @@ class DetectionModule(export_base.ExportModule):
input_image_shape = image_info[:, 1, :] input_image_shape = image_info[:, 1, :]
detections = self._model.call( detections = self.model.call(
images=images, images=images,
image_shape=input_image_shape, image_shape=input_image_shape,
anchor_boxes=anchor_boxes, anchor_boxes=anchor_boxes,
......
...@@ -38,35 +38,10 @@ class DetectionExportTest(tf.test.TestCase, parameterized.TestCase): ...@@ -38,35 +38,10 @@ class DetectionExportTest(tf.test.TestCase, parameterized.TestCase):
params, batch_size=1, input_image_size=[640, 640]) params, batch_size=1, input_image_size=[640, 640])
return detection_module return detection_module
def _export_from_module(self, module, input_type, batch_size, save_directory): def _export_from_module(self, module, input_type, save_directory):
if input_type == 'image_tensor': signatures = module.get_inference_signatures(
input_signature = tf.TensorSpec( {input_type: 'serving_default'})
shape=[batch_size, None, None, 3], dtype=tf.uint8) tf.saved_model.save(module, save_directory, signatures=signatures)
signatures = {
'serving_default':
module.inference_from_image_tensors.get_concrete_function(
input_signature)
}
elif input_type == 'image_bytes':
input_signature = tf.TensorSpec(shape=[batch_size], dtype=tf.string)
signatures = {
'serving_default':
module.inference_from_image_bytes.get_concrete_function(
input_signature)
}
elif input_type == 'tf_example':
input_signature = tf.TensorSpec(shape=[batch_size], dtype=tf.string)
signatures = {
'serving_default':
module.inference_from_tf_example.get_concrete_function(
input_signature)
}
else:
raise ValueError('Unrecognized `input_type`')
tf.saved_model.save(module,
save_directory,
signatures=signatures)
def _get_dummy_input(self, input_type, batch_size, image_size): def _get_dummy_input(self, input_type, batch_size, image_size):
"""Get dummy input for the given input type.""" """Get dummy input for the given input type."""
...@@ -107,23 +82,23 @@ class DetectionExportTest(tf.test.TestCase, parameterized.TestCase): ...@@ -107,23 +82,23 @@ class DetectionExportTest(tf.test.TestCase, parameterized.TestCase):
) )
def test_export(self, input_type, experiment_name, image_size): def test_export(self, input_type, experiment_name, image_size):
tmp_dir = self.get_temp_dir() tmp_dir = self.get_temp_dir()
batch_size = 1
module = self._get_detection_module(experiment_name) module = self._get_detection_module(experiment_name)
model = module.build_model()
self._export_from_module(module, input_type, batch_size, tmp_dir) self._export_from_module(module, input_type, tmp_dir)
self.assertTrue(os.path.exists(os.path.join(tmp_dir, 'saved_model.pb'))) self.assertTrue(os.path.exists(os.path.join(tmp_dir, 'saved_model.pb')))
self.assertTrue(os.path.exists( self.assertTrue(
os.path.join(tmp_dir, 'variables', 'variables.index'))) os.path.exists(os.path.join(tmp_dir, 'variables', 'variables.index')))
self.assertTrue(os.path.exists( self.assertTrue(
os.path.join(tmp_dir, 'variables', 'variables.data-00000-of-00001'))) os.path.exists(
os.path.join(tmp_dir, 'variables',
'variables.data-00000-of-00001')))
imported = tf.saved_model.load(tmp_dir) imported = tf.saved_model.load(tmp_dir)
detection_fn = imported.signatures['serving_default'] detection_fn = imported.signatures['serving_default']
images = self._get_dummy_input(input_type, batch_size, image_size) images = self._get_dummy_input(
input_type, batch_size=1, image_size=image_size)
processed_images, anchor_boxes, image_info = module._build_inputs( processed_images, anchor_boxes, image_info = module._build_inputs(
tf.zeros((224, 224, 3), dtype=tf.uint8)) tf.zeros((224, 224, 3), dtype=tf.uint8))
...@@ -133,7 +108,7 @@ class DetectionExportTest(tf.test.TestCase, parameterized.TestCase): ...@@ -133,7 +108,7 @@ class DetectionExportTest(tf.test.TestCase, parameterized.TestCase):
for l, l_boxes in anchor_boxes.items(): for l, l_boxes in anchor_boxes.items():
anchor_boxes[l] = tf.expand_dims(l_boxes, 0) anchor_boxes[l] = tf.expand_dims(l_boxes, 0)
expected_outputs = model( expected_outputs = module.model(
images=processed_images, images=processed_images,
image_shape=image_shape, image_shape=image_shape,
anchor_boxes=anchor_boxes, anchor_boxes=anchor_boxes,
...@@ -143,5 +118,6 @@ class DetectionExportTest(tf.test.TestCase, parameterized.TestCase): ...@@ -143,5 +118,6 @@ class DetectionExportTest(tf.test.TestCase, parameterized.TestCase):
self.assertAllClose(outputs['num_detections'].numpy(), self.assertAllClose(outputs['num_detections'].numpy(),
expected_outputs['num_detections'].numpy()) expected_outputs['num_detections'].numpy())
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
...@@ -16,20 +16,22 @@ ...@@ -16,20 +16,22 @@
"""Base class for model export.""" """Base class for model export."""
import abc import abc
from typing import Optional, Sequence, Mapping from typing import Dict, List, Mapping, Optional, Text
import tensorflow as tf import tensorflow as tf
from official.core import export_base
from official.modeling.hyperparams import config_definitions as cfg from official.modeling.hyperparams import config_definitions as cfg
class ExportModule(tf.Module, metaclass=abc.ABCMeta): class ExportModule(export_base.ExportModule, metaclass=abc.ABCMeta):
"""Base Export Module.""" """Base Export Module."""
def __init__(self, def __init__(self,
params: cfg.ExperimentConfig, params: cfg.ExperimentConfig,
*,
batch_size: int, batch_size: int,
input_image_size: Sequence[int], input_image_size: List[int],
num_channels: int = 3, num_channels: int = 3,
model: Optional[tf.keras.Model] = None): model: Optional[tf.keras.Model] = None):
"""Initializes a module for export. """Initializes a module for export.
...@@ -42,13 +44,13 @@ class ExportModule(tf.Module, metaclass=abc.ABCMeta): ...@@ -42,13 +44,13 @@ class ExportModule(tf.Module, metaclass=abc.ABCMeta):
num_channels: The number of the image channels. num_channels: The number of the image channels.
model: A tf.keras.Model instance to be exported. model: A tf.keras.Model instance to be exported.
""" """
self.params = params
super(ExportModule, self).__init__()
self._params = params
self._batch_size = batch_size self._batch_size = batch_size
self._input_image_size = input_image_size self._input_image_size = input_image_size
self._num_channels = num_channels self._num_channels = num_channels
self._model = model if model is None:
model = self._build_model() # pylint: disable=assignment-from-none
super().__init__(params=params, model=model)
def _decode_image(self, encoded_image_bytes: str) -> tf.Tensor: def _decode_image(self, encoded_image_bytes: str) -> tf.Tensor:
"""Decodes an image bytes to an image tensor. """Decodes an image bytes to an image tensor.
...@@ -92,45 +94,40 @@ class ExportModule(tf.Module, metaclass=abc.ABCMeta): ...@@ -92,45 +94,40 @@ class ExportModule(tf.Module, metaclass=abc.ABCMeta):
image_tensor = self._decode_image(parsed_tensors['image/encoded']) image_tensor = self._decode_image(parsed_tensors['image/encoded'])
return image_tensor return image_tensor
@abc.abstractmethod def _build_model(self, **kwargs):
def build_model(self, **kwargs): """Returns a model built from the params."""
"""Builds model and sets self._model.""" return None
@abc.abstractmethod
def _run_inference_on_image_tensors(
self, images: tf.Tensor) -> Mapping[str, tf.Tensor]:
"""Runs inference on images."""
@tf.function @tf.function
def inference_from_image_tensors( def inference_from_image_tensors(
self, input_tensor: tf.Tensor) -> Mapping[str, tf.Tensor]: self, inputs: tf.Tensor) -> Mapping[str, tf.Tensor]:
return self._run_inference_on_image_tensors(input_tensor) return self.serve(inputs)
@tf.function @tf.function
def inference_from_image_bytes(self, input_tensor: str): def inference_from_image_bytes(self, inputs: tf.Tensor):
with tf.device('cpu:0'): with tf.device('cpu:0'):
images = tf.nest.map_structure( images = tf.nest.map_structure(
tf.identity, tf.identity,
tf.map_fn( tf.map_fn(
self._decode_image, self._decode_image,
elems=input_tensor, elems=inputs,
fn_output_signature=tf.TensorSpec( fn_output_signature=tf.TensorSpec(
shape=[None] * len(self._input_image_size) + shape=[None] * len(self._input_image_size) +
[self._num_channels], [self._num_channels],
dtype=tf.uint8), dtype=tf.uint8),
parallel_iterations=32)) parallel_iterations=32))
images = tf.stack(images) images = tf.stack(images)
return self._run_inference_on_image_tensors(images) return self.serve(images)
@tf.function @tf.function
def inference_from_tf_example( def inference_from_tf_example(self,
self, input_tensor: tf.train.Example) -> Mapping[str, tf.Tensor]: inputs: tf.Tensor) -> Mapping[str, tf.Tensor]:
with tf.device('cpu:0'): with tf.device('cpu:0'):
images = tf.nest.map_structure( images = tf.nest.map_structure(
tf.identity, tf.identity,
tf.map_fn( tf.map_fn(
self._decode_tf_example, self._decode_tf_example,
elems=input_tensor, elems=inputs,
# Height/width of the shape of input images is unspecified (None) # Height/width of the shape of input images is unspecified (None)
# at the time of decoding the example, but the shape will # at the time of decoding the example, but the shape will
# be adjusted to conform to the input layer of the model, # be adjusted to conform to the input layer of the model,
...@@ -142,4 +139,41 @@ class ExportModule(tf.Module, metaclass=abc.ABCMeta): ...@@ -142,4 +139,41 @@ class ExportModule(tf.Module, metaclass=abc.ABCMeta):
dtype=tf.uint8, dtype=tf.uint8,
parallel_iterations=32)) parallel_iterations=32))
images = tf.stack(images) images = tf.stack(images)
return self._run_inference_on_image_tensors(images) return self.serve(images)
def get_inference_signatures(self, function_keys: Dict[Text, Text]):
"""Gets defined function signatures.
Args:
function_keys: A dictionary with keys as the function to create signature
for and values as the signature keys when returns.
Returns:
A dictionary with key as signature key and value as concrete functions
that can be used for tf.saved_model.save.
"""
signatures = {}
for key, def_name in function_keys.items():
if key == 'image_tensor':
input_signature = tf.TensorSpec(
shape=[self._batch_size] + [None] * len(self._input_image_size) +
[self._num_channels],
dtype=tf.uint8)
signatures[
def_name] = self.inference_from_image_tensors.get_concrete_function(
input_signature)
elif key == 'image_bytes':
input_signature = tf.TensorSpec(
shape=[self._batch_size], dtype=tf.string)
signatures[
def_name] = self.inference_from_image_bytes.get_concrete_function(
input_signature)
elif key == 'serve_examples' or key == 'tf_example':
input_signature = tf.TensorSpec(
shape=[self._batch_size], dtype=tf.string)
signatures[
def_name] = self.inference_from_tf_example.get_concrete_function(
input_signature)
else:
raise ValueError('Unrecognized `input_type`')
return signatures
...@@ -16,16 +16,15 @@ ...@@ -16,16 +16,15 @@
r"""Vision models export utility function for serving/inference.""" r"""Vision models export utility function for serving/inference."""
import os import os
from typing import Optional, List from typing import Optional, List
import tensorflow as tf import tensorflow as tf
from official.core import config_definitions as cfg from official.core import config_definitions as cfg
from official.core import export_base
from official.core import train_utils from official.core import train_utils
from official.vision.beta import configs from official.vision.beta import configs
from official.vision.beta.serving import detection from official.vision.beta.serving import detection
from official.vision.beta.serving import export_base
from official.vision.beta.serving import image_classification from official.vision.beta.serving import image_classification
from official.vision.beta.serving import semantic_segmentation from official.vision.beta.serving import semantic_segmentation
...@@ -75,6 +74,7 @@ def export_inference_graph( ...@@ -75,6 +74,7 @@ def export_inference_graph(
else: else:
output_saved_model_directory = export_dir output_saved_model_directory = export_dir
# TODO(arashwan): Offers a direct path to use ExportModule with Task objects.
if not export_module: if not export_module:
if isinstance(params.task, if isinstance(params.task,
configs.image_classification.ImageClassificationTask): configs.image_classification.ImageClassificationTask):
...@@ -101,47 +101,13 @@ def export_inference_graph( ...@@ -101,47 +101,13 @@ def export_inference_graph(
raise ValueError('Export module not implemented for {} task.'.format( raise ValueError('Export module not implemented for {} task.'.format(
type(params.task))) type(params.task)))
model = export_module.build_model() export_base.export(
export_module,
ckpt = tf.train.Checkpoint(model=model) function_keys=[input_type],
export_savedmodel_dir=output_saved_model_directory,
ckpt_dir_or_file = checkpoint_path checkpoint_path=checkpoint_path,
if tf.io.gfile.isdir(ckpt_dir_or_file): timestamped=False)
ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
status = ckpt.restore(ckpt_dir_or_file).expect_partial()
if input_type == 'image_tensor':
input_signature = tf.TensorSpec(
shape=[batch_size] + [None] * len(input_image_size) + [num_channels],
dtype=tf.uint8)
signatures = {
'serving_default':
export_module.inference_from_image_tensors.get_concrete_function(
input_signature)
}
elif input_type == 'image_bytes':
input_signature = tf.TensorSpec(shape=[batch_size], dtype=tf.string)
signatures = {
'serving_default':
export_module.inference_from_image_bytes.get_concrete_function(
input_signature)
}
elif input_type == 'tf_example':
input_signature = tf.TensorSpec(shape=[batch_size], dtype=tf.string)
signatures = {
'serving_default':
export_module.inference_from_tf_example.get_concrete_function(
input_signature)
}
else:
raise ValueError('Unrecognized `input_type`')
status.assert_existing_objects_matched()
ckpt = tf.train.Checkpoint(model=export_module.model)
ckpt.save(os.path.join(output_checkpoint_directory, 'ckpt')) ckpt.save(os.path.join(output_checkpoint_directory, 'ckpt'))
tf.saved_model.save(export_module,
output_saved_model_directory,
signatures=signatures)
train_utils.serialize_config(params, export_dir) train_utils.serialize_config(params, export_dir)
...@@ -24,7 +24,7 @@ import tensorflow as tf ...@@ -24,7 +24,7 @@ import tensorflow as tf
from official.common import registry_imports # pylint: disable=unused-import from official.common import registry_imports # pylint: disable=unused-import
from official.core import exp_factory from official.core import exp_factory
from official.modeling import hyperparams from official.modeling import hyperparams
from official.vision.beta.serving import image_classification from official.vision.beta.modeling import factory
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
...@@ -68,10 +68,14 @@ def export_model_to_tfhub(params, ...@@ -68,10 +68,14 @@ def export_model_to_tfhub(params,
checkpoint_path, checkpoint_path,
export_path): export_path):
"""Export an image classification model to TF-Hub.""" """Export an image classification model to TF-Hub."""
export_module = image_classification.ClassificationModule( input_specs = tf.keras.layers.InputSpec(shape=[batch_size] +
params=params, batch_size=batch_size, input_image_size=input_image_size) input_image_size + [3])
model = export_module.build_model(skip_logits_layer=skip_logits_layer) model = factory.build_classification_model(
input_specs=input_specs,
model_config=params.task.model,
l2_regularizer=None,
skip_logits_layer=skip_logits_layer)
checkpoint = tf.train.Checkpoint(model=model) checkpoint = tf.train.Checkpoint(model=model)
checkpoint.restore(checkpoint_path).assert_existing_objects_matched() checkpoint.restore(checkpoint_path).assert_existing_objects_matched()
model.save(export_path, include_optimizer=False, save_format='tf') model.save(export_path, include_optimizer=False, save_format='tf')
......
...@@ -29,17 +29,14 @@ STDDEV_RGB = (0.229 * 255, 0.224 * 255, 0.225 * 255) ...@@ -29,17 +29,14 @@ STDDEV_RGB = (0.229 * 255, 0.224 * 255, 0.225 * 255)
class ClassificationModule(export_base.ExportModule): class ClassificationModule(export_base.ExportModule):
"""classification Module.""" """classification Module."""
def build_model(self, skip_logits_layer=False): def _build_model(self):
input_specs = tf.keras.layers.InputSpec( input_specs = tf.keras.layers.InputSpec(
shape=[self._batch_size] + self._input_image_size + [3]) shape=[self._batch_size] + self._input_image_size + [3])
self._model = factory.build_classification_model( return factory.build_classification_model(
input_specs=input_specs, input_specs=input_specs,
model_config=self._params.task.model, model_config=self.params.task.model,
l2_regularizer=None, l2_regularizer=None)
skip_logits_layer=skip_logits_layer)
return self._model
def _build_inputs(self, image): def _build_inputs(self, image):
"""Builds classification model inputs for serving.""" """Builds classification model inputs for serving."""
...@@ -58,7 +55,7 @@ class ClassificationModule(export_base.ExportModule): ...@@ -58,7 +55,7 @@ class ClassificationModule(export_base.ExportModule):
scale=STDDEV_RGB) scale=STDDEV_RGB)
return image return image
def _run_inference_on_image_tensors(self, images): def serve(self, images):
"""Cast image to float and run inference. """Cast image to float and run inference.
Args: Args:
...@@ -79,6 +76,6 @@ class ClassificationModule(export_base.ExportModule): ...@@ -79,6 +76,6 @@ class ClassificationModule(export_base.ExportModule):
) )
) )
logits = self._model(images, training=False) logits = self.inference_step(images)
return dict(outputs=logits) return dict(outputs=logits)
...@@ -38,30 +38,8 @@ class ImageClassificationExportTest(tf.test.TestCase, parameterized.TestCase): ...@@ -38,30 +38,8 @@ class ImageClassificationExportTest(tf.test.TestCase, parameterized.TestCase):
return classification_module return classification_module
def _export_from_module(self, module, input_type, save_directory): def _export_from_module(self, module, input_type, save_directory):
if input_type == 'image_tensor': signatures = module.get_inference_signatures(
input_signature = tf.TensorSpec(shape=[None, 224, 224, 3], dtype=tf.uint8) {input_type: 'serving_default'})
signatures = {
'serving_default':
module.inference_from_image_tensors.get_concrete_function(
input_signature)
}
elif input_type == 'image_bytes':
input_signature = tf.TensorSpec(shape=[None], dtype=tf.string)
signatures = {
'serving_default':
module.inference_from_image_bytes.get_concrete_function(
input_signature)
}
elif input_type == 'tf_example':
input_signature = tf.TensorSpec(shape=[None], dtype=tf.string)
signatures = {
'serving_default':
module.inference_from_tf_example.get_concrete_function(
input_signature)
}
else:
raise ValueError('Unrecognized `input_type`')
tf.saved_model.save(module, tf.saved_model.save(module,
save_directory, save_directory,
signatures=signatures) signatures=signatures)
...@@ -95,9 +73,7 @@ class ImageClassificationExportTest(tf.test.TestCase, parameterized.TestCase): ...@@ -95,9 +73,7 @@ class ImageClassificationExportTest(tf.test.TestCase, parameterized.TestCase):
) )
def test_export(self, input_type='image_tensor'): def test_export(self, input_type='image_tensor'):
tmp_dir = self.get_temp_dir() tmp_dir = self.get_temp_dir()
module = self._get_classification_module() module = self._get_classification_module()
model = module.build_model()
self._export_from_module(module, input_type, tmp_dir) self._export_from_module(module, input_type, tmp_dir)
...@@ -118,7 +94,7 @@ class ImageClassificationExportTest(tf.test.TestCase, parameterized.TestCase): ...@@ -118,7 +94,7 @@ class ImageClassificationExportTest(tf.test.TestCase, parameterized.TestCase):
elems=tf.zeros((1, 224, 224, 3), dtype=tf.uint8), elems=tf.zeros((1, 224, 224, 3), dtype=tf.uint8),
fn_output_signature=tf.TensorSpec( fn_output_signature=tf.TensorSpec(
shape=[224, 224, 3], dtype=tf.float32))) shape=[224, 224, 3], dtype=tf.float32)))
expected_output = model(processed_images, training=False) expected_output = module.model(processed_images, training=False)
out = classification_fn(tf.constant(images)) out = classification_fn(tf.constant(images))
self.assertAllClose(out['outputs'].numpy(), expected_output.numpy()) self.assertAllClose(out['outputs'].numpy(), expected_output.numpy())
......
...@@ -29,17 +29,15 @@ STDDEV_RGB = (0.229 * 255, 0.224 * 255, 0.225 * 255) ...@@ -29,17 +29,15 @@ STDDEV_RGB = (0.229 * 255, 0.224 * 255, 0.225 * 255)
class SegmentationModule(export_base.ExportModule): class SegmentationModule(export_base.ExportModule):
"""Segmentation Module.""" """Segmentation Module."""
def build_model(self): def _build_model(self):
input_specs = tf.keras.layers.InputSpec( input_specs = tf.keras.layers.InputSpec(
shape=[self._batch_size] + self._input_image_size + [3]) shape=[self._batch_size] + self._input_image_size + [3])
self._model = factory.build_segmentation_model( return factory.build_segmentation_model(
input_specs=input_specs, input_specs=input_specs,
model_config=self._params.task.model, model_config=self.params.task.model,
l2_regularizer=None) l2_regularizer=None)
return self._model
def _build_inputs(self, image): def _build_inputs(self, image):
"""Builds classification model inputs for serving.""" """Builds classification model inputs for serving."""
...@@ -56,7 +54,7 @@ class SegmentationModule(export_base.ExportModule): ...@@ -56,7 +54,7 @@ class SegmentationModule(export_base.ExportModule):
aug_scale_max=1.0) aug_scale_max=1.0)
return image return image
def _run_inference_on_image_tensors(self, images): def serve(self, images):
"""Cast image to float and run inference. """Cast image to float and run inference.
Args: Args:
...@@ -77,7 +75,7 @@ class SegmentationModule(export_base.ExportModule): ...@@ -77,7 +75,7 @@ class SegmentationModule(export_base.ExportModule):
) )
) )
masks = self._model(images, training=False) masks = self.inference_step(images)
masks = tf.image.resize(masks, self._input_image_size, method='bilinear') masks = tf.image.resize(masks, self._input_image_size, method='bilinear')
return dict(predicted_masks=masks) return dict(predicted_masks=masks)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment