Commit 460890ed authored by A. Unique TensorFlower's avatar A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 406888835
parent f2bc366e
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains common building blocks for centernet neural networks."""
from typing import List, Optional
import tensorflow as tf
from official.vision.beta.modeling.layers import nn_blocks
def _apply_blocks(inputs, blocks):
"""Apply blocks to inputs."""
net = inputs
for block in blocks:
net = block(net)
return net
def _make_repeated_residual_blocks(
reps: int,
out_channels: int,
use_sync_bn: bool = True,
norm_momentum: float = 0.1,
norm_epsilon: float = 1e-5,
residual_channels: Optional[int] = None,
initial_stride: int = 1,
initial_skip_conv: bool = False,
kernel_initializer: str = 'VarianceScaling',
kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
bias_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
):
"""Stack Residual blocks one after the other.
Args:
reps: `int` for desired number of residual blocks
out_channels: `int`, filter depth of the final residual block
use_sync_bn: A `bool`, if True, use synchronized batch normalization.
norm_momentum: `float`, momentum for the batch normalization layers
norm_epsilon: `float`, epsilon for the batch normalization layers
residual_channels: `int`, filter depth for the first reps - 1 residual
blocks. If None, defaults to the same value as out_channels. If not
equal to out_channels, then uses a projection shortcut in the final
residual block
initial_stride: `int`, stride for the first residual block
initial_skip_conv: `bool`, if set, the first residual block uses a skip
convolution. This is useful when the number of channels in the input
are not the same as residual_channels.
kernel_initializer: A `str` for kernel initializer of convolutional layers.
kernel_regularizer: A `tf.keras.regularizers.Regularizer` object for
Conv2D. Default to None.
bias_regularizer: A `tf.keras.regularizers.Regularizer` object for Conv2D.
Default to None.
Returns:
A `tf.keras.Sequential` containing the stacked residual blocks.
"""
blocks = []
if residual_channels is None:
residual_channels = out_channels
for i in range(reps - 1):
# Only use the stride at the first block so we don't repeatedly downsample
# the input
stride = initial_stride if i == 0 else 1
# If the stride is more than 1, we cannot use an identity layer for the
# skip connection and are forced to use a conv for the skip connection.
skip_conv = stride > 1
if i == 0 and initial_skip_conv:
skip_conv = True
blocks.append(nn_blocks.ResidualBlock(
filters=residual_channels,
strides=stride,
use_explicit_padding=True,
use_projection=skip_conv,
use_sync_bn=use_sync_bn,
norm_momentum=norm_momentum,
norm_epsilon=norm_epsilon,
kernel_initializer=kernel_initializer,
kernel_regularizer=kernel_regularizer,
bias_regularizer=bias_regularizer))
if reps == 1:
# If there is only 1 block, the `for` loop above is not run,
# therefore we honor the requested stride in the last residual block
stride = initial_stride
# We are forced to use a conv in the skip connection if stride > 1
skip_conv = stride > 1
else:
stride = 1
skip_conv = residual_channels != out_channels
blocks.append(nn_blocks.ResidualBlock(
filters=out_channels,
strides=stride,
use_explicit_padding=True,
use_projection=skip_conv,
use_sync_bn=use_sync_bn,
norm_momentum=norm_momentum,
norm_epsilon=norm_epsilon,
kernel_initializer=kernel_initializer,
kernel_regularizer=kernel_regularizer,
bias_regularizer=bias_regularizer))
return tf.keras.Sequential(blocks)
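# Illustrative usage (a sketch, not part of the original file): three residual
# blocks where the first downsamples by 2; the projected shortcut in the first
# block handles the 128 -> 256 channel change.
#
#   blocks = _make_repeated_residual_blocks(
#       reps=3, out_channels=256, initial_stride=2)
#   y = blocks(tf.ones([1, 64, 64, 128]))  # y.shape == [1, 32, 32, 256]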
@tf.keras.utils.register_keras_serializable(package='centernet')
class HourglassBlock(tf.keras.layers.Layer):
"""Hourglass module: an encoder-decoder block."""
def __init__(
self,
channel_dims_per_stage: List[int],
blocks_per_stage: List[int],
strides: int = 1,
use_sync_bn: bool = True,
norm_momentum: float = 0.1,
norm_epsilon: float = 1e-5,
kernel_initializer: str = 'VarianceScaling',
kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
bias_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
**kwargs):
"""Initialize Hourglass module.
Args:
channel_dims_per_stage: List[int], list of filter sizes for the Residual
blocks, i.e. the output channel dimensions of the stages in the
network. `channel_dims_per_stage[0]` defines the number of channels in
the first encoder block and `channel_dims_per_stage[1]` the number of
channels in the second encoder block. The channels of the recursive
inner layers are defined using `channel_dims_per_stage[1:]`. For
example, [nc * 2, nc * 2, nc * 3, nc * 3, nc * 3, nc * 4], where nc is
the input channel dimension.
blocks_per_stage: List[int], list of residual block repetitions per
down/upsample. `blocks_per_stage[0]` defines the number of blocks at the
current stage and `blocks_per_stage[1:]` is used at further stages.
For example, [2, 2, 2, 2, 2, 4].
strides: `int`, stride parameter to the Residual block.
use_sync_bn: A `bool`, if True, use synchronized batch normalization.
norm_momentum: `float`, momentum for the batch normalization layers.
norm_epsilon: `float`, epsilon for the batch normalization layers.
kernel_initializer: A `str` for kernel initializer of conv layers.
kernel_regularizer: A `tf.keras.regularizers.Regularizer` object for
Conv2D. Default to None.
bias_regularizer: A `tf.keras.regularizers.Regularizer` object for Conv2D.
Default to None.
**kwargs: Additional keyword arguments to be passed.
"""
super(HourglassBlock, self).__init__(**kwargs)
if len(channel_dims_per_stage) != len(blocks_per_stage):
raise ValueError('filter size and residual block repetition '
'lists must have the same length')
self._num_stages = len(channel_dims_per_stage) - 1
self._channel_dims_per_stage = channel_dims_per_stage
self._blocks_per_stage = blocks_per_stage
self._strides = strides
self._use_sync_bn = use_sync_bn
self._norm_momentum = norm_momentum
self._norm_epsilon = norm_epsilon
self._kernel_initializer = kernel_initializer
self._kernel_regularizer = kernel_regularizer
self._bias_regularizer = bias_regularizer
self._filters = channel_dims_per_stage[0]
if self._num_stages > 0:
self._filters_downsampled = channel_dims_per_stage[1]
self._reps = blocks_per_stage[0]
def build(self, input_shape):
if self._num_stages == 0:
# base case, residual block repetitions in most inner part of hourglass
self.blocks = _make_repeated_residual_blocks(
reps=self._reps,
out_channels=self._filters,
use_sync_bn=self._use_sync_bn,
norm_momentum=self._norm_momentum,
norm_epsilon=self._norm_epsilon,
bias_regularizer=self._bias_regularizer,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer)
else:
# outer hourglass structures
self.encoder_block1 = _make_repeated_residual_blocks(
reps=self._reps,
out_channels=self._filters,
use_sync_bn=self._use_sync_bn,
norm_momentum=self._norm_momentum,
norm_epsilon=self._norm_epsilon,
bias_regularizer=self._bias_regularizer,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer)
self.encoder_block2 = _make_repeated_residual_blocks(
reps=self._reps,
out_channels=self._filters_downsampled,
initial_stride=2,
use_sync_bn=self._use_sync_bn,
norm_momentum=self._norm_momentum,
norm_epsilon=self._norm_epsilon,
bias_regularizer=self._bias_regularizer,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
initial_skip_conv=self._filters != self._filters_downsampled)
# recursively define inner hourglasses
self.inner_hg = type(self)(
channel_dims_per_stage=self._channel_dims_per_stage[1:],
blocks_per_stage=self._blocks_per_stage[1:],
strides=self._strides)
# outer hourglass structures
self.decoder_block = _make_repeated_residual_blocks(
reps=self._reps,
residual_channels=self._filters_downsampled,
out_channels=self._filters,
use_sync_bn=self._use_sync_bn,
norm_momentum=self._norm_momentum,
norm_epsilon=self._norm_epsilon,
bias_regularizer=self._bias_regularizer,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer)
self.upsample_layer = tf.keras.layers.UpSampling2D(
size=2,
interpolation='nearest')
super(HourglassBlock, self).build(input_shape)
def call(self, x, training=None):
if self._num_stages == 0:
return self.blocks(x)
else:
encoded_outputs = self.encoder_block1(x)
encoded_downsampled_outputs = self.encoder_block2(x)
inner_outputs = self.inner_hg(encoded_downsampled_outputs)
hg_output = self.decoder_block(inner_outputs)
return self.upsample_layer(hg_output) + encoded_outputs
def get_config(self):
config = {
'channel_dims_per_stage': self._channel_dims_per_stage,
'blocks_per_stage': self._blocks_per_stage,
'strides': self._strides,
'use_sync_bn': self._use_sync_bn,
'norm_momentum': self._norm_momentum,
'norm_epsilon': self._norm_epsilon,
'kernel_initializer': self._kernel_initializer,
'kernel_regularizer': self._kernel_regularizer,
'bias_regularizer': self._bias_regularizer,
}
config.update(super(HourglassBlock, self).get_config())
return config
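# Illustrative usage (a sketch): a small hourglass whose outermost channel
# dimension matches the 4-channel input; spatial dims must be divisible by
# 2**(len(channel_dims_per_stage) - 1). The output shape equals the input
# shape.
#
#   hg = HourglassBlock(
#       channel_dims_per_stage=[4, 6, 8, 10, 12],
#       blocks_per_stage=[2, 2, 2, 2, 2])
#   out = hg(tf.ones([1, 64, 64, 4]))  # out.shape == [1, 64, 64, 4]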
@tf.keras.utils.register_keras_serializable(package='centernet')
class CenterNetHeadConv(tf.keras.layers.Layer):
"""Convolution block for the CenterNet head."""
def __init__(self,
output_filters: int,
bias_init: float,
name: str,
**kwargs):
"""Initialize CenterNet head.
Args:
output_filters: `int`, channel depth of layer output
bias_init: `float`, value to initialize the bias vector for the final
convolution layer
name: `string`, layer name
**kwargs: Additional keyword arguments to be passed.
"""
super(CenterNetHeadConv, self).__init__(name=name, **kwargs)
self._output_filters = output_filters
self._bias_init = bias_init
def build(self, input_shape):
n_channels = input_shape[-1]
self.conv1 = tf.keras.layers.Conv2D(
filters=n_channels,
kernel_size=(3, 3),
padding='same')
self.relu = tf.keras.layers.ReLU()
# Initialize bias to the last Conv2D Layer
self.conv2 = tf.keras.layers.Conv2D(
filters=self._output_filters,
kernel_size=(1, 1),
padding='valid',
bias_initializer=tf.constant_initializer(self._bias_init))
super(CenterNetHeadConv, self).build(input_shape)
def call(self, x, training=None):
x = self.conv1(x)
x = self.relu(x)
x = self.conv2(x)
return x
def get_config(self):
config = {
'output_filters': self._output_filters,
'bias_init': self._bias_init,
}
config.update(super(CenterNetHeadConv, self).get_config())
return config
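# Illustrative usage (a sketch): a 90-class center-heatmap head. CenterNet-style
# models commonly initialize the final bias to log(p / (1 - p)) (about -2.19
# for p = 0.1) so the initial sigmoid output matches a rare-positive prior;
# the exact value here is an assumption, not something this file prescribes.
#
#   head = CenterNetHeadConv(
#       output_filters=90, bias_init=-2.19, name='ct_heatmaps_head')
#   logits = head(tf.ones([1, 128, 128, 256]))  # [1, 128, 128, 90]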
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for Centernet nn_blocks.
It is a literal translation of the PyTorch implementation.
"""
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from official.vision.beta.modeling.layers import nn_blocks
from official.vision.beta.projects.centernet.modeling.layers import cn_nn_blocks
class HourglassBlockPyTorch(tf.keras.layers.Layer):
"""An CornerNet-style implementation of the hourglass block."""
def __init__(self, dims, modules, k=0, **kwargs):
"""An CornerNet-style implementation of the hourglass block.
Args:
dims: input sizes of residual blocks
modules: number of repetitions of the residual blocks in each hourglass
upsampling and downsampling
k: recursive parameter
**kwargs: Additional keyword arguments to be passed.
"""
super(HourglassBlockPyTorch, self).__init__()
if len(dims) != len(modules):
raise ValueError('dims and modules lists must have the same length')
self.n = len(dims) - 1
self.k = k
self.modules = modules
self.dims = dims
self._kwargs = kwargs
def build(self, input_shape):
modules = self.modules
dims = self.dims
k = self.k
kwargs = self._kwargs
curr_mod = modules[k]
next_mod = modules[k + 1]
curr_dim = dims[k + 0]
next_dim = dims[k + 1]
self.up1 = self.make_up_layer(3, curr_dim, curr_dim, curr_mod, **kwargs)
self.max1 = tf.keras.layers.MaxPool2D(strides=2)
self.low1 = self.make_hg_layer(3, curr_dim, next_dim, curr_mod, **kwargs)
if self.n - k > 1:
self.low2 = type(self)(dims, modules, k=k + 1, **kwargs)
else:
self.low2 = self.make_low_layer(
3, next_dim, next_dim, next_mod, **kwargs)
self.low3 = self.make_hg_layer_revr(
3, next_dim, curr_dim, curr_mod, **kwargs)
self.up2 = tf.keras.layers.UpSampling2D(2)
self.merge = tf.keras.layers.Add()
super(HourglassBlockPyTorch, self).build(input_shape)
def call(self, x):
up1 = self.up1(x)
max1 = self.max1(x)
low1 = self.low1(max1)
low2 = self.low2(low1)
low3 = self.low3(low2)
up2 = self.up2(low3)
return self.merge([up1, up2])
def make_layer(self, k, inp_dim, out_dim, modules, **kwargs):
layers = [
nn_blocks.ResidualBlock(out_dim, 1, use_projection=True, **kwargs)]
for _ in range(1, modules):
layers.append(nn_blocks.ResidualBlock(out_dim, 1, **kwargs))
return tf.keras.Sequential(layers)
def make_layer_revr(self, k, inp_dim, out_dim, modules, **kwargs):
layers = []
for _ in range(modules - 1):
layers.append(
nn_blocks.ResidualBlock(inp_dim, 1, **kwargs))
layers.append(
nn_blocks.ResidualBlock(out_dim, 1, use_projection=True, **kwargs))
return tf.keras.Sequential(layers)
def make_up_layer(self, k, inp_dim, out_dim, modules, **kwargs):
return self.make_layer(k, inp_dim, out_dim, modules, **kwargs)
def make_low_layer(self, k, inp_dim, out_dim, modules, **kwargs):
return self.make_layer(k, inp_dim, out_dim, modules, **kwargs)
def make_hg_layer(self, k, inp_dim, out_dim, modules, **kwargs):
return self.make_layer(k, inp_dim, out_dim, modules, **kwargs)
def make_hg_layer_revr(self, k, inp_dim, out_dim, modules, **kwargs):
return self.make_layer_revr(k, inp_dim, out_dim, modules, **kwargs)
class NNBlocksTest(parameterized.TestCase, tf.test.TestCase):
def test_hourglass_block(self):
dims = [256, 256, 384, 384, 384, 512]
modules = [2, 2, 2, 2, 2, 4]
model = cn_nn_blocks.HourglassBlock(dims, modules)
test_input = tf.keras.Input((512, 512, 256))
_ = model(test_input)
filter_sizes = [256, 256, 384, 384, 384, 512]
rep_sizes = [2, 2, 2, 2, 2, 4]
hg_test_input_shape = (1, 512, 512, 256)
# bb_test_input_shape = (1, 512, 512, 3)
x_hg = tf.ones(shape=hg_test_input_shape)
# x_bb = tf.ones(shape=bb_test_input_shape)
hg = cn_nn_blocks.HourglassBlock(
channel_dims_per_stage=filter_sizes,
blocks_per_stage=rep_sizes)
hg.build(input_shape=hg_test_input_shape)
out = hg(x_hg)
self.assertAllEqual(
tf.shape(out), hg_test_input_shape,
'Hourglass module output shape and expected shape differ')
# ODAPI Test
layer = cn_nn_blocks.HourglassBlock(
blocks_per_stage=[2, 3, 4, 5, 6],
channel_dims_per_stage=[4, 6, 8, 10, 12])
output = layer(np.zeros((2, 64, 64, 4), dtype=np.float32))
self.assertEqual(output.shape, (2, 64, 64, 4))
if __name__ == '__main__':
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Detection generator for centernet.
Parses predictions from the CenterNet head into the final bounding boxes,
confidences, and classes. This class contains repurposed methods from the
TensorFlow Object Detection API
in: https://github.com/tensorflow/models/blob/master/research/object_detection
/meta_architectures/center_net_meta_arch.py
"""
from typing import Any, Mapping
import tensorflow as tf
from official.vision.beta.ops import box_ops
from official.vision.beta.projects.centernet.ops import loss_ops
from official.vision.beta.projects.centernet.ops import nms_ops
@tf.keras.utils.register_keras_serializable(package='centernet')
class CenterNetDetectionGenerator(tf.keras.layers.Layer):
"""CenterNet Detection Generator."""
def __init__(self,
input_image_dims: int = 512,
net_down_scale: int = 4,
max_detections: int = 100,
peak_error: float = 1e-6,
peak_extract_kernel_size: int = 3,
class_offset: int = 1,
use_nms: bool = False,
nms_pre_thresh: float = 0.1,
nms_thresh: float = 0.4,
**kwargs):
"""Initialize CenterNet Detection Generator.
Args:
input_image_dims: An `int` that specifies the input image size.
net_down_scale: An `int` that specifies stride of the output.
max_detections: An `int` specifying the maximum number of bounding
boxes generated. This is an upper bound, so the number of generated
boxes may be less than this due to thresholding/non-maximum suppression.
peak_error: A `float` for determining non-valid heatmap locations to mask.
peak_extract_kernel_size: An `int` indicating the kernel size used when
max-pooling over the heatmaps to detect valid center locations from
their neighbors. Following the paper, set this to 3 to keep locations
whose responses are greater than those of their 8-connected neighbors.
class_offset: An `int` indicating to add an offset to the class
prediction if the dataset labels have been shifted.
use_nms: A `bool` for whether or not to use non-maximum suppression to
filter the bounding boxes.
nms_pre_thresh: A `float` for pre-nms threshold.
nms_thresh: A `float` for nms threshold.
**kwargs: Additional keyword arguments to be passed.
"""
super(CenterNetDetectionGenerator, self).__init__(**kwargs)
# Object center selection parameters
self._max_detections = max_detections
self._peak_error = peak_error
self._peak_extract_kernel_size = peak_extract_kernel_size
# Used for adjusting class prediction
self._class_offset = class_offset
# Box normalization parameters
self._net_down_scale = net_down_scale
self._input_image_dims = input_image_dims
self._use_nms = use_nms
self._nms_pre_thresh = nms_pre_thresh
self._nms_thresh = nms_thresh
def process_heatmap(self,
feature_map: tf.Tensor,
kernel_size: int) -> tf.Tensor:
"""Processes the heatmap into peaks for box selection.
Given a heatmap, this function first masks out nearby heatmap locations of
the same class using max-pooling such that, ideally, only one center for the
object remains. Then, center locations are masked according to their scores
in comparison to a threshold. NOTE: Repurposed from Google OD API.
Args:
feature_map: A Tensor with shape [batch_size, height, width, num_classes]
which is the center heatmap predictions.
kernel_size: An integer value for max-pool kernel size.
Returns:
A Tensor with the same shape as the input but with non-valid center
prediction locations masked out.
"""
feature_map = tf.math.sigmoid(feature_map)
if not kernel_size or kernel_size == 1:
feature_map_peaks = feature_map
else:
feature_map_max_pool = tf.nn.max_pool(
feature_map,
ksize=kernel_size,
strides=1,
padding='SAME')
feature_map_peak_mask = tf.math.abs(
feature_map - feature_map_max_pool) < self._peak_error
# Zero out everything that is not a peak.
feature_map_peaks = (
feature_map * tf.cast(feature_map_peak_mask, feature_map.dtype))
return feature_map_peaks
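# Illustrative behavior (a sketch, assuming `generator` is an instance of this
# class): with a 3x3 max pool, only local maxima of the sigmoid scores survive
# and every other location is zeroed.
#
#   hm = tf.reshape(tf.constant(
#       [[1., 3., 2.], [4., 9., 5.], [2., 6., 1.]]), [1, 3, 3, 1])
#   peaks = generator.process_heatmap(hm, kernel_size=3)
#   # Only the center location keeps its score, sigmoid(9.0).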
def get_top_k_peaks(self,
feature_map_peaks: tf.Tensor,
batch_size: int,
width: int,
num_classes: int,
k: int = 100):
"""Gets the scores and indices of the top-k peaks from the feature map.
This function flattens the feature map in order to retrieve the top-k
peaks, then computes the x, y, and class indices for those scores.
NOTE: Repurposed from Google OD API.
Args:
feature_map_peaks: A `Tensor` with shape [batch_size, height,
width, num_classes] which is the processed center heatmap peaks.
batch_size: An `int` that indicates the batch size of the input.
width: An `int` that indicates the width (and also height) of the input.
num_classes: An `int` for the number of possible classes. This is also
the channel depth of the input.
k: An `int` that controls how many peaks to select.
Returns:
top_scores: A Tensor with shape [batch_size, k] containing the top-k
scores.
y_indices: A Tensor with shape [batch_size, k] containing the top-k
y-indices corresponding to top_scores.
x_indices: A Tensor with shape [batch_size, k] containing the top-k
x-indices corresponding to top_scores.
channel_indices: A Tensor with shape [batch_size, k] containing the top-k
channel indices corresponding to top_scores.
"""
# Flatten the entire prediction per batch
feature_map_peaks_flat = tf.reshape(feature_map_peaks, [batch_size, -1])
# top_scores and top_indices have shape [batch_size, k]
top_scores, top_indices = tf.math.top_k(feature_map_peaks_flat, k=k)
# Get x, y and channel indices corresponding to the top indices in the flat
# array.
y_indices, x_indices, channel_indices = (
loss_ops.get_row_col_channel_indices_from_flattened_indices(
top_indices, width, num_classes))
return top_scores, y_indices, x_indices, channel_indices
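# Illustrative usage (a sketch): for a [2, 128, 128, 90] peaks tensor, the
# call below returns four [2, 100] tensors, ordered by descending score per
# batch item.
#
#   scores, y_idx, x_idx, c_idx = generator.get_top_k_peaks(
#       peaks, batch_size=2, width=128, num_classes=90, k=100)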
def get_boxes(self,
y_indices: tf.Tensor,
x_indices: tf.Tensor,
channel_indices: tf.Tensor,
height_width_predictions: tf.Tensor,
offset_predictions: tf.Tensor,
num_boxes: int):
"""Organizes prediction information into the final bounding boxes.
NOTE: Repurposed from Google OD API.
Args:
y_indices: A Tensor with shape [batch_size, k] containing the top-k
y-indices corresponding to top_scores.
x_indices: A Tensor with shape [batch_size, k] containing the top-k
x-indices corresponding to top_scores.
channel_indices: A Tensor with shape [batch_size, k] containing the top-k
channel indices corresponding to top_scores.
height_width_predictions: A Tensor with shape [batch_size, height,
width, 2] containing the object size predictions.
offset_predictions: A Tensor with shape [batch_size, height, width, 2]
containing the object local offset predictions.
num_boxes: `int`, the number of boxes.
Returns:
boxes: A Tensor with shape [batch_size, num_boxes, 4] that contains the
bounding box coordinates in [y_min, x_min, y_max, x_max] format.
detection_classes: A Tensor with shape [batch_size, num_boxes] that
gives the class prediction for each box.
"""
# TF Lite does not support tf.gather with batch_dims > 0, so we need to use
# tf_gather_nd instead and here we prepare the indices for that.
# shapes of heatmap output
shape = tf.shape(height_width_predictions)
batch_size, height, width = shape[0], shape[1], shape[2]
# combined indices dtype=int32
combined_indices = tf.stack([
loss_ops.multi_range(batch_size, value_repetitions=num_boxes),
tf.reshape(y_indices, [-1]),
tf.reshape(x_indices, [-1])
], axis=1)
new_height_width = tf.gather_nd(height_width_predictions, combined_indices)
new_height_width = tf.reshape(new_height_width, [batch_size, num_boxes, 2])
height_width = tf.maximum(new_height_width, 0.0)
# height and widths dtype=float32
heights = height_width[..., 0]
widths = height_width[..., 1]
# Get the offsets of center points
new_offsets = tf.gather_nd(offset_predictions, combined_indices)
offsets = tf.reshape(new_offsets, [batch_size, num_boxes, 2])
# offsets are dtype=float32
y_offsets = offsets[..., 0]
x_offsets = offsets[..., 1]
y_indices = tf.cast(y_indices, dtype=heights.dtype)
x_indices = tf.cast(x_indices, dtype=widths.dtype)
detection_classes = channel_indices + self._class_offset
ymin = y_indices + y_offsets - heights / 2.0
xmin = x_indices + x_offsets - widths / 2.0
ymax = y_indices + y_offsets + heights / 2.0
xmax = x_indices + x_offsets + widths / 2.0
ymin = tf.clip_by_value(ymin, 0., tf.cast(height, ymin.dtype))
xmin = tf.clip_by_value(xmin, 0., tf.cast(width, xmin.dtype))
ymax = tf.clip_by_value(ymax, 0., tf.cast(height, ymax.dtype))
xmax = tf.clip_by_value(xmax, 0., tf.cast(width, xmax.dtype))
boxes = tf.stack([ymin, xmin, ymax, xmax], axis=2)
return boxes, detection_classes
def convert_strided_predictions_to_normalized_boxes(self, boxes: tf.Tensor):
"""Rescales boxes from feature-map (strided) space to normalized [0, 1]."""
boxes = boxes * tf.cast(self._net_down_scale, boxes.dtype)
boxes = boxes / tf.cast(self._input_image_dims, boxes.dtype)
boxes = tf.clip_by_value(boxes, 0.0, 1.0)
return boxes
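# Sketch of the scaling: with net_down_scale=4 and input_image_dims=512, a box
# edge at feature-map coordinate 64 maps to 64 * 4 / 512 = 0.5 in normalized
# image space, and values are clipped to [0, 1].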
def __call__(self, inputs):
# Get heatmaps from decoded outputs via final hourglass stack output
all_ct_heatmaps = inputs['ct_heatmaps']
all_ct_sizes = inputs['ct_size']
all_ct_offsets = inputs['ct_offset']
ct_heatmaps = all_ct_heatmaps[-1]
ct_sizes = all_ct_sizes[-1]
ct_offsets = all_ct_offsets[-1]
shape = tf.shape(ct_heatmaps)
_, width = shape[1], shape[2]
batch_size, num_channels = shape[0], shape[3]
# Process heatmaps using 3x3 max pool and applying sigmoid
peaks = self.process_heatmap(
feature_map=ct_heatmaps,
kernel_size=self._peak_extract_kernel_size)
# Get top scores along with their x, y, and class
# Each has size [batch_size, k]
scores, y_indices, x_indices, channel_indices = self.get_top_k_peaks(
feature_map_peaks=peaks,
batch_size=batch_size,
width=width,
num_classes=num_channels,
k=self._max_detections)
# Parse the score and indices into bounding boxes
boxes, classes = self.get_boxes(
y_indices=y_indices,
x_indices=x_indices,
channel_indices=channel_indices,
height_width_predictions=ct_sizes,
offset_predictions=ct_offsets,
num_boxes=self._max_detections)
# Normalize bounding boxes
boxes = self.convert_strided_predictions_to_normalized_boxes(boxes)
# Apply nms
if self._use_nms:
boxes = tf.expand_dims(boxes, axis=-2)
multi_class_scores = tf.gather_nd(
peaks, tf.stack([y_indices, x_indices], -1), batch_dims=1)
boxes, _, scores = nms_ops.nms(
boxes=boxes,
classes=multi_class_scores,
confidence=scores,
k=self._max_detections,
limit_pre_thresh=True,
pre_nms_thresh=self._nms_pre_thresh,
nms_thresh=self._nms_thresh)
num_det = tf.reduce_sum(tf.cast(scores > 0, dtype=tf.int32), axis=1)
boxes = box_ops.denormalize_boxes(
boxes, [self._input_image_dims, self._input_image_dims])
return {
'boxes': boxes,
'classes': classes,
'confidence': scores,
'num_detections': num_det
}
def get_config(self) -> Mapping[str, Any]:
config = {
'max_detections': self._max_detections,
'peak_error': self._peak_error,
'peak_extract_kernel_size': self._peak_extract_kernel_size,
'class_offset': self._class_offset,
'net_down_scale': self._net_down_scale,
'input_image_dims': self._input_image_dims,
'use_nms': self._use_nms,
'nms_pre_thresh': self._nms_pre_thresh,
'nms_thresh': self._nms_thresh
}
base_config = super(CenterNetDetectionGenerator, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
@classmethod
def from_config(cls, config):
return cls(**config)
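# Illustrative usage (a sketch; the dict keys follow what __call__ reads, and
# all values are dummy zeros with an assumed 90-class heatmap): each entry is
# a list of per-hourglass outputs, and only the last element is decoded.
#
#   generator = CenterNetDetectionGenerator(input_image_dims=512)
#   outputs = {
#       'ct_heatmaps': [tf.zeros([2, 128, 128, 90])],
#       'ct_size': [tf.zeros([2, 128, 128, 2])],
#       'ct_offset': [tf.zeros([2, 128, 128, 2])],
#   }
#   dets = generator(outputs)
#   # dets has 'boxes', 'classes', 'confidence' and 'num_detections'.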
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Bounding Box List definition.
BoxList represents a list of bounding boxes as tensorflow
tensors, where each bounding box is represented as a row of 4 numbers,
[y_min, x_min, y_max, x_max]. It is assumed that all bounding boxes
within a given list correspond to a single image. See also
box_list_ops.py for common box related operations (such as area, iou, etc).
Optionally, users can add additional related fields (such as weights).
We assume the following things to be true about fields:
* they correspond to boxes in the box_list along the 0th dimension
* they have inferrable rank at graph construction time
* all dimensions except for possibly the 0th can be inferred
(i.e., not None) at graph construction time.
Some other notes:
* Following tensorflow conventions, we use height, width ordering,
and correspondingly, y,x (or ymin, xmin, ymax, xmax) ordering
* Tensors are always provided as (flat) [N, 4] tensors.
"""
import tensorflow as tf
def _get_dim_as_int(dim):
"""Utility to get v1 or v2 TensorShape dim as an int.
Args:
dim: The TensorShape dimension to get as an int
Returns:
None or an int.
"""
try:
return dim.value
except AttributeError:
return dim
class BoxList(object):
"""Box collection."""
def __init__(self, boxes):
"""Constructs box collection.
Args:
boxes: a tensor of shape [N, 4] representing box corners
Raises:
ValueError: if invalid dimensions for bbox data or if bbox data is not in
float32 format.
"""
if len(boxes.get_shape()) != 2 or boxes.get_shape()[-1] != 4:
raise ValueError('Invalid dimensions for box data: {}'.format(
boxes.shape))
if boxes.dtype != tf.float32:
raise ValueError('Invalid tensor type: should be tf.float32')
self.data = {'boxes': boxes}
def num_boxes(self):
"""Returns number of boxes held in collection.
Returns:
a tensor representing the number of boxes held in the collection.
"""
return tf.shape(self.data['boxes'])[0]
def num_boxes_static(self):
"""Returns number of boxes held in collection.
This number is inferred at graph construction time rather than run-time.
Returns:
Number of boxes held in collection (integer) or None if this is not
inferrable at graph construction time.
"""
return _get_dim_as_int(self.data['boxes'].get_shape()[0])
def get_all_fields(self):
"""Returns all fields."""
return self.data.keys()
def get_extra_fields(self):
"""Returns all non-box fields (i.e., everything not named 'boxes')."""
return [k for k in self.data.keys() if k != 'boxes']
def add_field(self, field, field_data):
"""Add field to box list.
This method can be used to add related box data such as
weights/labels, etc.
Args:
field: a string key to access the data via `get`
field_data: a tensor containing the data to store in the BoxList
"""
self.data[field] = field_data
def has_field(self, field):
return field in self.data
def get(self):
"""Convenience function for accessing box coordinates.
Returns:
a tensor with shape [N, 4] representing box coordinates.
"""
return self.get_field('boxes')
def set(self, boxes):
"""Convenience function for setting box coordinates.
Args:
boxes: a tensor of shape [N, 4] representing box corners
Raises:
ValueError: if invalid dimensions for bbox data
"""
if len(boxes.get_shape()) != 2 or boxes.get_shape()[-1] != 4:
raise ValueError('Invalid dimensions for box data.')
self.data['boxes'] = boxes
def get_field(self, field):
"""Accesses a box collection and associated fields.
This function returns the tensor stored under the given field name,
including the box coordinates themselves when `field` is 'boxes'.
Args:
field: a string specifying the field to be accessed.
Returns:
a tensor representing the box collection or an associated field.
Raises:
ValueError: if invalid field
"""
if not self.has_field(field):
raise ValueError('field ' + str(field) + ' does not exist')
return self.data[field]
def set_field(self, field, value):
"""Sets the value of a field.
Updates the field of a box_list with a given value.
Args:
field: (string) name of the field whose value is to be set.
value: the value to assign to the field.
Raises:
ValueError: if the box_list does not have specified field.
"""
if not self.has_field(field):
raise ValueError('field %s does not exist' % field)
self.data[field] = value
def get_center_coordinates_and_sizes(self):
"""Computes the center coordinates, height and width of the boxes.
Returns:
a list of 4 1-D tensors [ycenter, xcenter, height, width].
"""
with tf.name_scope('get_center_coordinates_and_sizes'):
box_corners = self.get()
ymin, xmin, ymax, xmax = tf.unstack(tf.transpose(box_corners))
width = xmax - xmin
height = ymax - ymin
ycenter = ymin + height / 2.
xcenter = xmin + width / 2.
return [ycenter, xcenter, height, width]
def transpose_coordinates(self):
"""Transpose the coordinate representation in a boxlist."""
with tf.name_scope('transpose_coordinates'):
y_min, x_min, y_max, x_max = tf.split(
value=self.get(), num_or_size_splits=4, axis=1)
self.set(tf.concat([x_min, y_min, x_max, y_max], 1))
def as_tensor_dict(self, fields=None):
"""Retrieves specified fields as a dictionary of tensors.
Args:
fields: (optional) list of fields to return in the dictionary.
If None (default), all fields are returned.
Returns:
tensor_dict: A dictionary of tensors specified by fields.
Raises:
ValueError: if specified field is not contained in boxlist.
"""
tensor_dict = {}
if fields is None:
fields = self.get_all_fields()
for field in fields:
if not self.has_field(field):
raise ValueError('boxlist must contain all specified fields')
tensor_dict[field] = self.get_field(field)
return tensor_dict
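# Illustrative usage (a sketch): boxes must be a float32 [N, 4] tensor in
# [ymin, xmin, ymax, xmax] order; extra per-box fields share the 0th dimension.
#
#   boxlist = BoxList(tf.constant(
#       [[0., 0., 10., 10.], [2., 2., 6., 8.]], tf.float32))
#   boxlist.add_field('scores', tf.constant([0.9, 0.4]))
#   boxlist.num_boxes_static()  # -> 2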
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Bounding Box List operations."""
import tensorflow as tf
from official.vision.beta.ops import sampling_ops
from official.vision.beta.projects.centernet.ops import box_list
def _copy_extra_fields(boxlist_to_copy_to, boxlist_to_copy_from):
"""Copies the extra fields of boxlist_to_copy_from to boxlist_to_copy_to.
Args:
boxlist_to_copy_to: BoxList to which extra fields are copied.
boxlist_to_copy_from: BoxList from which fields are copied.
Returns:
boxlist_to_copy_to with extra fields.
"""
for field in boxlist_to_copy_from.get_extra_fields():
boxlist_to_copy_to.add_field(field, boxlist_to_copy_from.get_field(field))
return boxlist_to_copy_to
def scale(boxlist, y_scale, x_scale):
"""scale box coordinates in x and y dimensions.
Args:
boxlist: BoxList holding N boxes
y_scale: (float) scalar tensor
x_scale: (float) scalar tensor
Returns:
boxlist: BoxList holding N boxes
"""
with tf.name_scope('Scale'):
y_scale = tf.cast(y_scale, tf.float32)
x_scale = tf.cast(x_scale, tf.float32)
y_min, x_min, y_max, x_max = tf.split(
value=boxlist.get(), num_or_size_splits=4, axis=1)
y_min = y_scale * y_min
y_max = y_scale * y_max
x_min = x_scale * x_min
x_max = x_scale * x_max
scaled_boxlist = box_list.BoxList(
tf.concat([y_min, x_min, y_max, x_max], 1))
return _copy_extra_fields(scaled_boxlist, boxlist)
def area(boxlist):
"""Computes area of boxes.
Args:
boxlist: BoxList holding N boxes
Returns:
a tensor with shape [N] representing box areas.
"""
with tf.name_scope('Area'):
y_min, x_min, y_max, x_max = tf.split(
value=boxlist.get(), num_or_size_splits=4, axis=1)
return tf.squeeze((y_max - y_min) * (x_max - x_min), [1])
def change_coordinate_frame(boxlist, window):
"""Change coordinate frame of the boxlist to be relative to window's frame.
Given a window of the form [ymin, xmin, ymax, xmax],
changes bounding box coordinates from boxlist to be relative to this window
(e.g., the min corner maps to (0,0) and the max corner maps to (1,1)).
An example use case is data augmentation: where we are given groundtruth
boxes (boxlist) and would like to randomly crop the image to some
window (window). In this case we need to change the coordinate frame of
each groundtruth box to be relative to this new window.
Args:
boxlist: A BoxList object holding N boxes.
window: A rank 1 tensor [4].
Returns:
A BoxList object with N boxes in the window's coordinate frame.
"""
with tf.name_scope('ChangeCoordinateFrame'):
win_height = window[2] - window[0]
win_width = window[3] - window[1]
boxlist_new = scale(box_list.BoxList(
boxlist.get() - [window[0], window[1], window[0], window[1]]),
1.0 / win_height, 1.0 / win_width)
boxlist_new = _copy_extra_fields(boxlist_new, boxlist)
return boxlist_new
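# Worked example (a sketch): with window = [0.25, 0.25, 0.75, 0.75], the box
# [0.25, 0.25, 0.75, 0.75] is shifted to [0., 0., 0.5, 0.5] and then scaled by
# 1 / 0.5 in both axes, giving [0., 0., 1., 1.] in the window's frame.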
def matmul_gather_on_zeroth_axis(params, indices):
"""Matrix multiplication based implementation of tf.gather on zeroth axis.
Args:
params: A float32 Tensor. The tensor from which to gather values.
Must be at least rank 1.
indices: A Tensor. Must be one of the following types: int32, int64.
Must be in range [0, params.shape[0])
Returns:
A Tensor. Has the same type as params. Values from params gathered
from indices given by indices, with shape indices.shape + params.shape[1:].
"""
with tf.name_scope('MatMulGather'):
params_shape = sampling_ops.combined_static_and_dynamic_shape(params)
indices_shape = sampling_ops.combined_static_and_dynamic_shape(indices)
params2d = tf.reshape(params, [params_shape[0], -1])
indicator_matrix = tf.one_hot(indices, params_shape[0])
gathered_result_flattened = tf.matmul(indicator_matrix, params2d)
return tf.reshape(gathered_result_flattened,
tf.stack(indices_shape + params_shape[1:]))
def gather(boxlist, indices, fields=None, use_static_shapes=False):
"""Gather boxes from BoxList according to indices and return new BoxList.
By default, `gather` returns boxes corresponding to the input index list, as
well as all additional fields stored in the boxlist (indexing into the
first dimension). However one can optionally only gather from a
subset of fields.
Args:
boxlist: BoxList holding N boxes
indices: a rank-1 tensor of type int32 / int64
fields: (optional) list of fields to also gather from. If None (default),
all fields are gathered from. Pass an empty fields list to only gather
the box coordinates.
use_static_shapes: Whether to use an implementation with static shape
guarantees.
Returns:
subboxlist: a BoxList corresponding to the subset of the input BoxList
specified by indices
Raises:
ValueError: if specified field is not contained in boxlist or if the
indices are not of type int32
"""
with tf.name_scope('Gather'):
if len(indices.shape.as_list()) != 1:
raise ValueError('indices should have rank 1')
if indices.dtype != tf.int32 and indices.dtype != tf.int64:
raise ValueError('indices should be an int32 / int64 tensor')
gather_op = tf.gather
if use_static_shapes:
gather_op = matmul_gather_on_zeroth_axis
subboxlist = box_list.BoxList(gather_op(boxlist.get(), indices))
if fields is None:
fields = boxlist.get_extra_fields()
fields += ['boxes']
for field in fields:
if not boxlist.has_field(field):
raise ValueError('boxlist must contain all specified fields')
subfieldlist = gather_op(boxlist.get_field(field), indices)
subboxlist.add_field(field, subfieldlist)
return subboxlist
def prune_completely_outside_window(boxlist, window):
"""Prunes bounding boxes that fall completely outside of the given window.
The function clip_to_window prunes bounding boxes that fall
completely outside the window, but also clips any bounding boxes that
partially overflow. This function does not clip partially overflowing boxes.
Args:
boxlist: a BoxList holding M_in boxes.
window: a float tensor of shape [4] representing [ymin, xmin, ymax, xmax]
of the window
Returns:
pruned_boxlist: a new BoxList with all bounding boxes partially or fully in
the window.
valid_indices: a tensor with shape [M_out] indexing the valid bounding boxes
in the input tensor.
"""
with tf.name_scope('PruneCompletelyOutsideWindow'):
y_min, x_min, y_max, x_max = tf.split(
value=boxlist.get(), num_or_size_splits=4, axis=1)
win_y_min, win_x_min, win_y_max, win_x_max = tf.unstack(window)
coordinate_violations = tf.concat([
tf.greater_equal(y_min, win_y_max), tf.greater_equal(x_min, win_x_max),
tf.less_equal(y_max, win_y_min), tf.less_equal(x_max, win_x_min)
], 1)
valid_indices = tf.reshape(
tf.where(tf.logical_not(tf.reduce_any(coordinate_violations, 1))), [-1])
return gather(boxlist, valid_indices), valid_indices
def clip_to_window(boxlist, window, filter_nonoverlapping=True):
"""Clip bounding boxes to a window.
This op clips any input bounding boxes (represented by bounding box
corners) to a window, optionally filtering out boxes that do not
overlap at all with the window.
Args:
boxlist: BoxList holding M_in boxes
window: a tensor of shape [4] representing the [y_min, x_min, y_max, x_max]
window to which the op should clip boxes.
filter_nonoverlapping: whether to filter out boxes that do not overlap at
all with the window.
Returns:
a BoxList holding M_out boxes where M_out <= M_in
"""
with tf.name_scope('ClipToWindow'):
y_min, x_min, y_max, x_max = tf.split(
value=boxlist.get(), num_or_size_splits=4, axis=1)
win_y_min = window[0]
win_x_min = window[1]
win_y_max = window[2]
win_x_max = window[3]
y_min_clipped = tf.maximum(tf.minimum(y_min, win_y_max), win_y_min)
y_max_clipped = tf.maximum(tf.minimum(y_max, win_y_max), win_y_min)
x_min_clipped = tf.maximum(tf.minimum(x_min, win_x_max), win_x_min)
x_max_clipped = tf.maximum(tf.minimum(x_max, win_x_max), win_x_min)
clipped = box_list.BoxList(
tf.concat([y_min_clipped, x_min_clipped, y_max_clipped, x_max_clipped],
1))
clipped = _copy_extra_fields(clipped, boxlist)
if filter_nonoverlapping:
areas = area(clipped)
nonzero_area_indices = tf.cast(
tf.reshape(tf.where(tf.greater(areas, 0.0)), [-1]), tf.int32)
clipped = gather(clipped, nonzero_area_indices)
return clipped
def height_width(boxlist):
"""Computes height and width of boxes in boxlist.
Args:
boxlist: BoxList holding N boxes
Returns:
Height: A tensor with shape [N] representing box heights.
Width: A tensor with shape [N] representing box widths.
"""
with tf.name_scope('HeightWidth'):
y_min, x_min, y_max, x_max = tf.split(
value=boxlist.get(), num_or_size_splits=4, axis=1)
return tf.squeeze(y_max - y_min, [1]), tf.squeeze(x_max - x_min, [1])
def prune_small_boxes(boxlist, min_side):
"""Prunes small boxes in the boxlist which have a side smaller than min_side.
Args:
boxlist: BoxList holding N boxes.
min_side: Minimum width AND height of box to survive pruning.
Returns:
A pruned boxlist.
"""
with tf.name_scope('PruneSmallBoxes'):
height, width = height_width(boxlist)
is_valid = tf.logical_and(tf.greater_equal(width, min_side),
tf.greater_equal(height, min_side))
return gather(boxlist, tf.reshape(tf.where(is_valid), [-1]))
def assert_or_prune_invalid_boxes(boxes):
"""Makes sure boxes have valid sizes (ymax >= ymin, xmax >= xmin).
When the hardware supports assertions, the function raises an error when
boxes have an invalid size. If assertions are not supported (e.g. on TPU),
boxes with invalid sizes are filtered out.
Args:
boxes: float tensor of shape [num_boxes, 4]
Returns:
boxes: float tensor of shape [num_valid_boxes, 4] with invalid boxes
filtered out.
Raises:
tf.errors.InvalidArgumentError: When we detect boxes with invalid size.
This is not supported on TPUs.
"""
ymin, xmin, ymax, xmax = tf.split(
boxes, num_or_size_splits=4, axis=1)
height_check = tf.Assert(tf.reduce_all(ymax >= ymin), [ymin, ymax])
width_check = tf.Assert(tf.reduce_all(xmax >= xmin), [xmin, xmax])
with tf.control_dependencies([height_check, width_check]):
boxes_tensor = tf.concat([ymin, xmin, ymax, xmax], axis=1)
boxlist = box_list.BoxList(boxes_tensor)
boxlist = prune_small_boxes(boxlist, 0)
return boxlist.get()
def to_absolute_coordinates(boxlist,
height,
width,
check_range=True,
maximum_normalized_coordinate=1.1):
"""Converts normalized box coordinates to absolute pixel coordinates.
This function raises an assertion failed error when the maximum box coordinate
value is larger than maximum_normalized_coordinate (in which case coordinates
are already absolute).
Args:
boxlist: BoxList with coordinates in range [0, 1].
height: Maximum value for height of absolute box coordinates.
width: Maximum value for width of absolute box coordinates.
check_range: If True, checks if the coordinates are normalized or not.
maximum_normalized_coordinate: Maximum coordinate value to be considered
as normalized, default to 1.1.
Returns:
boxlist with absolute coordinates in terms of the image size.
"""
with tf.name_scope('ToAbsoluteCoordinates'):
height = tf.cast(height, tf.float32)
width = tf.cast(width, tf.float32)
# Ensure range of input boxes is correct.
if check_range:
box_maximum = tf.reduce_max(boxlist.get())
max_assert = tf.Assert(
tf.greater_equal(maximum_normalized_coordinate, box_maximum),
['maximum box coordinate value is larger '
'than %f: ' % maximum_normalized_coordinate, box_maximum])
with tf.control_dependencies([max_assert]):
width = tf.identity(width)
return scale(boxlist, height, width)
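# Illustrative usage (a sketch): converting to a 128 x 256 (height x width)
# image multiplies y-coordinates by 128 and x-coordinates by 256.
#
#   bl = box_list.BoxList(tf.constant([[0.25, 0.5, 0.75, 1.0]], tf.float32))
#   abs_bl = to_absolute_coordinates(bl, height=128, width=256)
#   # abs_bl.get() -> [[32., 128., 96., 256.]]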
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Operations for compute losses for centernet."""
import tensorflow as tf
from official.vision.beta.ops import sampling_ops
def _get_shape(tensor, num_dims):
assert len(tensor.shape.as_list()) == num_dims
return sampling_ops.combined_static_and_dynamic_shape(tensor)
def flatten_spatial_dimensions(batch_images):
"""Flattens [batch, height, width, channels] to [batch, height*width, channels]."""
# pylint: disable=unbalanced-tuple-unpacking
batch_size, height, width, channels = _get_shape(batch_images, 4)
return tf.reshape(batch_images, [batch_size, height * width,
channels])
def multi_range(limit,
value_repetitions=1,
range_repetitions=1,
dtype=tf.int32):
"""Creates a sequence with optional value duplication and range repetition.
As an example (see the Args section for more details),
multi_range(limit=2, value_repetitions=3, range_repetitions=4) returns:
[0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1]
NOTE: Repurposed from Google OD API.
Args:
limit: A 0-D Tensor (scalar). Upper limit of sequence, exclusive.
value_repetitions: Integer. The number of times a value in the sequence is
repeated. With value_repetitions=3, the result is [0, 0, 0, 1, 1, 1, ..].
range_repetitions: Integer. The number of times the range is repeated. With
range_repetitions=3, the result is [0, 1, 2, .., 0, 1, 2, ..].
dtype: The type of the elements of the resulting tensor.
Returns:
A 1-D tensor of type `dtype` and size
[`limit` * `value_repetitions` * `range_repetitions`] that contains the
specified range with given repetitions.
"""
return tf.reshape(
tf.tile(
tf.expand_dims(tf.range(limit, dtype=dtype), axis=-1),
multiples=[range_repetitions, value_repetitions]), [-1])
def add_batch_to_indices(indices):
"""Prepends batch indices so the result can index batched tensors via tf.gather_nd."""
shape = tf.shape(indices)
batch_size = shape[0]
num_instances = shape[1]
batch_range = multi_range(limit=batch_size, value_repetitions=num_instances)
batch_range = tf.reshape(batch_range, shape=(batch_size, num_instances, 1))
return tf.concat([batch_range, indices], axis=-1)
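# Sketch: for `indices` of shape [2, 3, 2] (batch of 2, three (y, x) rows
# each), the result has shape [2, 3, 3] with rows (batch_index, y, x), ready
# for tf.gather_nd on a batched feature map.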
def get_num_instances_from_weights(gt_weights_list):
"""Computes the number of instances/boxes from the weights in a batch.
Args:
gt_weights_list: A list of float tensors with shape
[max_num_instances] representing whether there is an actual instance in
the image (with non-zero value) or is padded to match the
max_num_instances (with value 0.0). The list represents the batch
dimension.
Returns:
A scalar integer tensor indicating how many instances/boxes are in the
images in the batch. Note that this function is usually used to normalize
the loss, so the minimum return value is 1 to avoid division by zero.
"""
# This can execute in graph mode
gt_weights_list = tf.convert_to_tensor(
gt_weights_list, dtype=gt_weights_list[0].dtype)
num_instances = tf.map_fn(
fn=lambda x: tf.math.count_nonzero(x, dtype=gt_weights_list[0].dtype),
elems=gt_weights_list)
num_instances = tf.reduce_sum(num_instances)
num_instances = tf.maximum(num_instances, 1)
return num_instances
def get_batch_predictions_from_indices(batch_predictions, indices):
"""Gets the values of predictions in a batch at the given indices.
The indices are expected to come from the offset targets generation functions
in this library. The returned value is intended to be used inside a loss
function.
Args:
batch_predictions: A tensor of shape [batch_size, height, width, channels]
or [batch_size, height, width, class, channels] for class-specific
features (e.g. keypoint joint offsets).
indices: A tensor of shape [num_instances, 3] for single class features or
[num_instances, 4] for multiple classes features.
Returns:
values: A tensor of shape [num_instances, channels] holding the predicted
values at the given indices.
"""
return tf.gather_nd(batch_predictions, indices)
def get_valid_anchor_weights_in_flattened_image(true_image_shapes, height,
width):
"""Computes valid anchor weights for an image assuming pixels to be flattened.
This function is useful when we only want to penalize valid areas in the
image in the case when padding is used. The function assumes that the loss
function will be applied after flattening the spatial dimensions and returns
anchor weights accordingly.
Args:
true_image_shapes: An integer tensor of shape [batch_size, 3] representing
the true image shape (without padding) for each sample in the batch.
height: height of the prediction from the network.
width: width of the prediction from the network.
Returns:
valid_anchor_weights: a float tensor of shape [batch_size, height * width]
with 1s in locations where the spatial coordinates fall within the height
and width in true_image_shapes.
"""
indices = tf.reshape(tf.range(height * width), [1, -1])
batch_size = tf.shape(true_image_shapes)[0]
batch_indices = tf.ones((batch_size, 1), dtype=tf.int32) * indices
y_coords, x_coords, _ = get_row_col_channel_indices_from_flattened_indices(
batch_indices, width, 1)
max_y, max_x = true_image_shapes[:, 0], true_image_shapes[:, 1]
max_x = tf.cast(tf.expand_dims(max_x, 1), tf.float32)
max_y = tf.cast(tf.expand_dims(max_y, 1), tf.float32)
x_coords = tf.cast(x_coords, tf.float32)
y_coords = tf.cast(y_coords, tf.float32)
valid_mask = tf.math.logical_and(x_coords < max_x, y_coords < max_y)
return tf.cast(valid_mask, tf.float32)
def get_row_col_channel_indices_from_flattened_indices(indices: int,
num_cols: int,
num_channels: int):
"""Computes row, column and channel indices from flattened indices.
NOTE: Repurposed from Google OD API.
Args:
indices: An `int` tensor of any shape holding the indices in the flattened
space.
num_cols: `int`, number of columns in the image (width).
num_channels: `int`, number of channels in the image.
Returns:
row_indices: The row indices corresponding to each of the input indices.
Same shape as indices.
col_indices: The column indices corresponding to each of the input indices.
Same shape as indices.
channel_indices: The channel indices corresponding to each of the input
indices. Same shape as indices.
"""
# Avoid using mod operator to make the ops more easy to be compatible with
# different environments, e.g. WASM.
# all inputs and outputs are dtype int32
row_indices = (indices // num_channels) // num_cols
col_indices = (indices // num_channels) - row_indices * num_cols
channel_indices_temp = indices // num_channels
channel_indices = indices - channel_indices_temp * num_channels
return row_indices, col_indices, channel_indices
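# Worked example (a sketch): flattened over [height, width, num_channels], an
# index i = (row * num_cols + col) * num_channels + channel. With num_cols=4
# and num_channels=3, i = 29 decodes to row = 29 // 3 // 4 = 2,
# col = 29 // 3 - 2 * 4 = 1, channel = 29 - 9 * 3 = 2.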
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""nms computation."""
import tensorflow as tf
from official.vision.beta.projects.yolo.ops import box_ops
NMS_TILE_SIZE = 512
# pylint: disable=missing-function-docstring
def aggregated_comparative_iou(boxes1, boxes2=None, iou_type=0):
k = tf.shape(boxes1)[-2]
boxes1 = tf.expand_dims(boxes1, axis=-2)
boxes1 = tf.tile(boxes1, [1, 1, k, 1])
if boxes2 is not None:
boxes2 = tf.expand_dims(boxes2, axis=-2)
boxes2 = tf.tile(boxes2, [1, 1, k, 1])
boxes2 = tf.transpose(boxes2, perm=(0, 2, 1, 3))
else:
boxes2 = tf.transpose(boxes1, perm=(0, 2, 1, 3))
if iou_type == 0: # diou
_, iou = box_ops.compute_diou(boxes1, boxes2)
elif iou_type == 1: # giou
_, iou = box_ops.compute_giou(boxes1, boxes2)
else:
iou = box_ops.compute_iou(boxes1, boxes2, yxyx=True)
return iou
# pylint: disable=missing-function-docstring
def sort_drop(objectness, box, classifications_in, k):
objectness, ind = tf.math.top_k(objectness, k=k)
ind_m = tf.ones_like(ind) * tf.expand_dims(
tf.range(0,
tf.shape(objectness)[0]), axis=-1)
bind = tf.stack([tf.reshape(ind_m, [-1]), tf.reshape(ind, [-1])], axis=-1)
box = tf.gather_nd(box, bind)
classifications = tf.gather_nd(classifications_in, bind)
bsize = tf.shape(ind)[0]
box = tf.reshape(box, [bsize, k, -1])
classifications = tf.reshape(classifications, [bsize, k, -1])
return objectness, box, classifications
# pylint: disable=missing-function-docstring
def segment_nms(boxes, classes, confidence, k, iou_thresh):
mrange = tf.range(k)
mask_x = tf.tile(
tf.transpose(tf.expand_dims(mrange, axis=-1), perm=[1, 0]), [k, 1])
mask_y = tf.tile(tf.expand_dims(mrange, axis=-1), [1, k])
mask_diag = tf.expand_dims(mask_x > mask_y, axis=0)
iou = aggregated_comparative_iou(boxes, iou_type=0)
# Mask out boxes whose IoU with a higher-ranked box exceeds the threshold.
iou_mask = iou >= iou_thresh
iou_mask = tf.logical_and(mask_diag, iou_mask)
iou *= tf.cast(iou_mask, iou.dtype)
can_suppress_others = 1 - tf.cast(
tf.reduce_any(iou_mask, axis=-2), boxes.dtype)
raw = tf.cast(can_suppress_others, boxes.dtype)
boxes *= tf.expand_dims(raw, axis=-1)
confidence *= tf.cast(raw, confidence.dtype)
classes *= tf.cast(raw, classes.dtype)
return boxes, classes, confidence
# pylint: disable=missing-function-docstring
def nms(boxes,
classes,
confidence,
k,
pre_nms_thresh,
nms_thresh,
limit_pre_thresh=False,
use_classes=True):
if limit_pre_thresh:
confidence, boxes, classes = sort_drop(confidence, boxes, classes, k)
mask = tf.fill(
tf.shape(confidence), tf.cast(pre_nms_thresh, dtype=confidence.dtype))
mask = tf.math.ceil(tf.nn.relu(confidence - mask))
confidence = confidence * mask
mask = tf.expand_dims(mask, axis=-1)
boxes = boxes * mask
classes = classes * mask
if use_classes:
confidence = tf.reduce_max(classes, axis=-1)
confidence, boxes, classes = sort_drop(confidence, boxes, classes, k)
classes = tf.cast(tf.argmax(classes, axis=-1), tf.float32)
boxes, classes, confidence = segment_nms(boxes, classes, confidence, k,
nms_thresh)
confidence, boxes, classes = sort_drop(confidence, boxes, classes, k)
classes = tf.squeeze(classes, axis=-1)
return boxes, classes, confidence
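# Illustrative usage (a sketch, mirroring the call in the CenterNet detection
# generator): `classes` carries per-class scores for each box and `confidence`
# the peak scores; outputs keep k entries, with suppressed boxes zeroed out
# and survivors sorted to the front.
#
#   boxes, classes, confidence = nms(
#       boxes, multi_class_scores, scores, k=100,
#       pre_nms_thresh=0.1, nms_thresh=0.4, limit_pre_thresh=True)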
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Preprocessing ops imported from OD API."""
import functools
import tensorflow as tf
from official.vision.beta.projects.centernet.ops import box_list
from official.vision.beta.projects.centernet.ops import box_list_ops
def _get_or_create_preprocess_rand_vars(generator_func,
function_id,
preprocess_vars_cache,
key=''):
"""Returns a tensor stored in preprocess_vars_cache or using generator_func.
If the tensor was previously generated and appears in the PreprocessorCache,
the previously generated tensor will be returned. Otherwise, a new tensor
is generated using generator_func and stored in the cache.
Args:
generator_func: A 0-argument function that generates a tensor.
function_id: identifier for the preprocessing function used.
preprocess_vars_cache: PreprocessorCache object that records previously
performed augmentations. Updated in-place. If this
function is called multiple times with the same
non-null cache, it will perform deterministically.
key: identifier for the variable stored.
Returns:
The generated tensor.
"""
if preprocess_vars_cache is not None:
var = preprocess_vars_cache.get(function_id, key)
if var is None:
var = generator_func()
preprocess_vars_cache.update(function_id, key, var)
else:
var = generator_func()
return var
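# A minimal stand-in illustrating the cache interface assumed above. The real
# PreprocessorCache ships with the OD API; only the get/update contract
# matters to `_get_or_create_preprocess_rand_vars`.
class _MinimalPreprocessorCache:
  """Maps (function_id, key) pairs to previously generated tensors."""

  def __init__(self):
    self._cache = {}

  def get(self, function_id, key):
    # Returns None when no tensor has been recorded for this pair.
    return self._cache.get((function_id, key))

  def update(self, function_id, key, value):
    self._cache[(function_id, key)] = value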
def _random_integer(minval, maxval, seed):
"""Returns a random 0-D tensor between minval and maxval.
Args:
minval: minimum value of the random tensor.
maxval: maximum value of the random tensor.
seed: random seed.
Returns:
A random 0-D tensor between minval and maxval.
"""
return tf.random.uniform(
[], minval=minval, maxval=maxval, dtype=tf.int32, seed=seed)
def _get_crop_border(border, size):
"""Get the border of cropping."""
border = tf.cast(border, tf.float32)
size = tf.cast(size, tf.float32)
i = tf.math.ceil(tf.math.log(2.0 * border / size) / tf.math.log(2.0))
divisor = tf.pow(2.0, i)
divisor = tf.clip_by_value(divisor, 1, border)
divisor = tf.cast(divisor, tf.int32)
return tf.cast(border, tf.int32) // divisor
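# Worked example: the border is halved until it fits within half of the image
# dimension, e.g. for max_border=128:
#   _get_crop_border(128, 512) -> 128  (divisor 2**0, since 128 < 512 / 2)
#   _get_crop_border(128, 200) -> 64   (divisor 2**1, since 64 < 200 / 2)
#   _get_crop_border(128, 100) -> 32   (divisor 2**2, since 32 < 100 / 2)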
def random_square_crop_by_scale(image,
boxes,
labels,
max_border=128,
scale_min=0.6,
scale_max=1.3,
num_scales=8,
seed=None,
preprocess_vars_cache=None):
"""Randomly crop a square in proportion to scale and image size.
Extract a square sized crop from an image whose side length is sampled by
randomly scaling the maximum spatial dimension of the image. If part of
the crop falls outside the image, it is filled with zeros.
The augmentation is borrowed from [1]
[1]: https://arxiv.org/abs/1904.07850
Args:
image: rank 3 float32 tensor containing 1 image ->
[height, width, channels].
boxes: rank 2 float32 tensor containing the bounding boxes -> [N, 4].
Boxes are in normalized form meaning their coordinates vary
between [0, 1]. Each row is in the form of [ymin, xmin, ymax, xmax].
Boxes on the crop boundary are clipped to the boundary and boxes
falling outside the crop are ignored.
labels: rank 1 int32 tensor containing the object classes.
max_border: The maximum size of the border. The border defines distance in
pixels to the image boundaries that will not be considered as a center of
a crop. To make sure that the border does not go over the center of the
image, we choose the border value by computing the minimum k, such that
(max_border / (2**k)) < image_dimension/2.
scale_min: float, the minimum value for scale.
scale_max: float, the maximum value for scale.
num_scales: int, the number of discrete scale values to sample between
[scale_min, scale_max]
seed: random seed.
preprocess_vars_cache: PreprocessorCache object that records previously
performed augmentations. Updated in-place. If this
function is called multiple times with the same
non-null cache, it will perform deterministically.
Returns:
image: image which is the same rank as input image.
boxes: boxes which is the same rank as input boxes.
Boxes are in normalized form.
labels: new labels.
"""
img_shape = tf.shape(image)
height, width = img_shape[0], img_shape[1]
scales = tf.linspace(scale_min, scale_max, num_scales)
scale = _get_or_create_preprocess_rand_vars(
lambda: scales[_random_integer(0, num_scales, seed)],
'square_crop_scale',
preprocess_vars_cache, 'scale')
image_size = scale * tf.cast(tf.maximum(height, width), tf.float32)
image_size = tf.cast(image_size, tf.int32)
h_border = _get_crop_border(max_border, height)
w_border = _get_crop_border(max_border, width)
def y_function():
y = _random_integer(h_border,
tf.cast(height, tf.int32) - h_border + 1,
seed)
return y
def x_function():
x = _random_integer(w_border,
tf.cast(width, tf.int32) - w_border + 1,
seed)
return x
y_center = _get_or_create_preprocess_rand_vars(
y_function,
'square_crop_scale',
preprocess_vars_cache, 'y_center')
x_center = _get_or_create_preprocess_rand_vars(
x_function,
'square_crop_scale',
preprocess_vars_cache, 'x_center')
half_size = tf.cast(image_size / 2, tf.int32)
crop_ymin, crop_ymax = y_center - half_size, y_center + half_size
crop_xmin, crop_xmax = x_center - half_size, x_center + half_size
ymin = tf.maximum(crop_ymin, 0)
xmin = tf.maximum(crop_xmin, 0)
ymax = tf.minimum(crop_ymax, height - 1)
xmax = tf.minimum(crop_xmax, width - 1)
cropped_image = image[ymin:ymax, xmin:xmax]
offset_y = tf.maximum(0, ymin - crop_ymin)
offset_x = tf.maximum(0, xmin - crop_xmin)
oy_i = offset_y
ox_i = offset_x
output_image = tf.image.pad_to_bounding_box(
cropped_image, offset_height=oy_i, offset_width=ox_i,
target_height=image_size, target_width=image_size)
if ymin == 0:
# We might be padding the image.
box_ymin = -offset_y
else:
box_ymin = crop_ymin
if xmin == 0:
# We might be padding the image.
box_xmin = -offset_x
else:
box_xmin = crop_xmin
box_ymax = box_ymin + image_size
box_xmax = box_xmin + image_size
image_box = [box_ymin / height, box_xmin / width,
box_ymax / height, box_xmax / width]
boxlist = box_list.BoxList(boxes)
boxlist = box_list_ops.change_coordinate_frame(boxlist, image_box)
boxlist, indices = box_list_ops.prune_completely_outside_window(
boxlist, [0.0, 0.0, 1.0, 1.0])
boxlist = box_list_ops.clip_to_window(boxlist, [0.0, 0.0, 1.0, 1.0],
filter_nonoverlapping=False)
return_values = [output_image,
boxlist.get(),
tf.gather(labels, indices)]
return return_values
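# A minimal sketch of the crop augmentation on dummy inputs; boxes are
# normalized [ymin, xmin, ymax, xmax] as documented above, and all values
# here are illustrative.
def _example_random_square_crop():
  image = tf.random.uniform([480, 640, 3], maxval=255.0)
  boxes = tf.constant([[0.1, 0.1, 0.5, 0.5]], tf.float32)
  labels = tf.constant([1], tf.int32)
  return random_square_crop_by_scale(image, boxes, labels)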
def resize_to_range(image,
masks=None,
min_dimension=None,
max_dimension=None,
method=tf.image.ResizeMethod.BILINEAR,
pad_to_max_dimension=False,
per_channel_pad_value=(0, 0, 0)):
"""Resizes an image so its dimensions are within the provided value.
The output size can be described by two cases:
1. If the image can be rescaled so its minimum dimension is equal to the
provided value without the other dimension exceeding max_dimension,
then do so.
2. Otherwise, resize so the largest dimension is equal to max_dimension.
Args:
image: A 3D tensor of shape [height, width, channels]
masks: (optional) rank 3 float32 tensor with shape
[num_instances, height, width] containing instance masks.
min_dimension: (optional) (scalar) desired size of the smaller image
dimension.
max_dimension: (optional) (scalar) maximum allowed size
of the larger image dimension.
method: (optional) interpolation method used in resizing. Defaults to
BILINEAR.
pad_to_max_dimension: Whether to resize the image and pad it with zeros
so the resulting image is of the spatial size
[max_dimension, max_dimension]. If masks are included they are padded
similarly.
per_channel_pad_value: A tuple of per-channel scalar value to use for
padding. By default pads zeros.
Returns:
Note that the position of the resized_image_shape changes based on whether
masks are present.
resized_image: A 3D tensor of shape [new_height, new_width, channels],
where the image has been resized (with bilinear interpolation) so that
min(new_height, new_width) == min_dimension or
max(new_height, new_width) == max_dimension.
resized_masks: If masks is not None, also outputs masks. A 3D tensor of
shape [num_instances, new_height, new_width].
resized_image_shape: A 1D tensor of shape [3] containing shape of the
resized image.
Raises:
ValueError: if the image is not a 3D tensor.
"""
if len(image.get_shape()) != 3:
raise ValueError('Image should be 3D tensor')
def _resize_landscape_image(image):
# resize a landscape image
return tf.image.resize(
image, tf.stack([min_dimension, max_dimension]), method=method,
preserve_aspect_ratio=True)
def _resize_portrait_image(image):
# resize a portrait image
return tf.image.resize(
image, tf.stack([max_dimension, min_dimension]), method=method,
preserve_aspect_ratio=True)
with tf.name_scope('ResizeToRange'):
if image.get_shape().is_fully_defined():
if image.get_shape()[0] < image.get_shape()[1]:
new_image = _resize_landscape_image(image)
else:
new_image = _resize_portrait_image(image)
new_size = tf.constant(new_image.get_shape().as_list())
else:
new_image = tf.cond(
tf.less(tf.shape(image)[0], tf.shape(image)[1]),
lambda: _resize_landscape_image(image),
lambda: _resize_portrait_image(image))
new_size = tf.shape(new_image)
if pad_to_max_dimension:
channels = tf.unstack(new_image, axis=2)
if len(channels) != len(per_channel_pad_value):
raise ValueError('Number of channels must be equal to the length of '
'per-channel pad value.')
new_image = tf.stack(
[
tf.pad( # pylint: disable=g-complex-comprehension
channels[i], [[0, max_dimension - new_size[0]],
[0, max_dimension - new_size[1]]],
constant_values=per_channel_pad_value[i])
for i in range(len(channels))
],
axis=2)
new_image.set_shape([max_dimension, max_dimension, len(channels)])
result = [new_image, new_size]
if masks is not None:
new_masks = tf.expand_dims(masks, 3)
new_masks = tf.image.resize(
new_masks,
new_size[:-1],
method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
if pad_to_max_dimension:
new_masks = tf.image.pad_to_bounding_box(
new_masks, 0, 0, max_dimension, max_dimension)
new_masks = tf.squeeze(new_masks, 3)
result.append(new_masks)
return result
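# A minimal sketch: with min_dimension == max_dimension and padding enabled,
# the image keeps its aspect ratio and lands on a square canvas, a common
# setting for fixed-size detector inputs.
def _example_resize_to_range():
  image = tf.random.uniform([480, 640, 3])
  resized_image, resized_shape = resize_to_range(
      image, min_dimension=512, max_dimension=512, pad_to_max_dimension=True)
  # resized_image: [512, 512, 3]; resized_shape: the size before padding,
  # here [384, 512, 3].
  return resized_image, resized_shape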
def _augment_only_rgb_channels(image, augment_function):
"""Augments only the RGB slice of an image with additional channels."""
rgb_slice = image[:, :, :3]
augmented_rgb_slice = augment_function(rgb_slice)
image = tf.concat([augmented_rgb_slice, image[:, :, 3:]], -1)
return image
def random_adjust_brightness(image,
max_delta=0.2,
seed=None,
preprocess_vars_cache=None):
"""Randomly adjusts brightness.
Makes sure the output image is still between 0 and 255.
Args:
image: rank 3 float32 tensor containing 1 image -> [height, width, channels]
with pixel values varying between [0, 255].
max_delta: how much to change the brightness. A value between [0, 1).
seed: random seed.
preprocess_vars_cache: PreprocessorCache object that records previously
performed augmentations. Updated in-place. If this
function is called multiple times with the same
non-null cache, it will perform deterministically.
Returns:
image: image which is the same shape as input image.
"""
with tf.name_scope('RandomAdjustBrightness'):
generator_func = functools.partial(tf.random.uniform, [],
-max_delta, max_delta, seed=seed)
delta = _get_or_create_preprocess_rand_vars(
generator_func,
'adjust_brightness',
preprocess_vars_cache)
def _adjust_brightness(image):
image = tf.image.adjust_brightness(image / 255, delta) * 255
image = tf.clip_by_value(image, clip_value_min=0.0, clip_value_max=255.0)
return image
image = _augment_only_rgb_channels(image, _adjust_brightness)
return image
def random_adjust_contrast(image,
min_delta=0.8,
max_delta=1.25,
seed=None,
preprocess_vars_cache=None):
"""Randomly adjusts contrast.
Makes sure the output image is still between 0 and 255.
Args:
image: rank 3 float32 tensor containing 1 image -> [height, width, channels]
with pixel values varying between [0, 255].
min_delta: see max_delta.
max_delta: how much to change the contrast. Contrast will change with a
value between min_delta and max_delta. This value will be
multiplied to the current contrast of the image.
seed: random seed.
preprocess_vars_cache: PreprocessorCache object that records previously
performed augmentations. Updated in-place. If this
function is called multiple times with the same
non-null cache, it will perform deterministically.
Returns:
image: image which is the same shape as input image.
"""
with tf.name_scope('RandomAdjustContrast'):
generator_func = functools.partial(tf.random.uniform, [],
min_delta, max_delta, seed=seed)
contrast_factor = _get_or_create_preprocess_rand_vars(
generator_func,
'adjust_contrast',
preprocess_vars_cache)
def _adjust_contrast(image):
image = tf.image.adjust_contrast(image / 255, contrast_factor) * 255
image = tf.clip_by_value(image, clip_value_min=0.0, clip_value_max=255.0)
return image
image = _augment_only_rgb_channels(image, _adjust_contrast)
return image
def random_adjust_hue(image,
max_delta=0.02,
seed=None,
preprocess_vars_cache=None):
"""Randomly adjusts hue.
Makes sure the output image is still between 0 and 255.
Args:
image: rank 3 float32 tensor containing 1 image -> [height, width, channels]
with pixel values varying between [0, 255].
max_delta: change hue randomly with a value between 0 and max_delta.
seed: random seed.
preprocess_vars_cache: PreprocessorCache object that records previously
performed augmentations. Updated in-place. If this
function is called multiple times with the same
non-null cache, it will perform deterministically.
Returns:
image: image which is the same shape as input image.
"""
with tf.name_scope('RandomAdjustHue'):
generator_func = functools.partial(tf.random.uniform, [],
-max_delta, max_delta, seed=seed)
delta = _get_or_create_preprocess_rand_vars(
generator_func,
'adjust_hue',
preprocess_vars_cache)
def _adjust_hue(image):
image = tf.image.adjust_hue(image / 255, delta) * 255
image = tf.clip_by_value(image, clip_value_min=0.0, clip_value_max=255.0)
return image
image = _augment_only_rgb_channels(image, _adjust_hue)
return image
def random_adjust_saturation(image,
min_delta=0.8,
max_delta=1.25,
seed=None,
preprocess_vars_cache=None):
"""Randomly adjusts saturation.
Makes sure the output image is still between 0 and 255.
Args:
image: rank 3 float32 tensor containing 1 image -> [height, width, channels]
with pixel values varying between [0, 255].
min_delta: see max_delta.
max_delta: how much to change the saturation. Saturation will change with a
value between min_delta and max_delta. This value will be
multiplied to the current saturation of the image.
seed: random seed.
preprocess_vars_cache: PreprocessorCache object that records previously
performed augmentations. Updated in-place. If this
function is called multiple times with the same
non-null cache, it will perform deterministically.
Returns:
image: image which is the same shape as input image.
"""
with tf.name_scope('RandomAdjustSaturation'):
generator_func = functools.partial(tf.random.uniform, [],
min_delta, max_delta, seed=seed)
saturation_factor = _get_or_create_preprocess_rand_vars(
generator_func,
'adjust_saturation',
preprocess_vars_cache)
def _adjust_saturation(image):
image = tf.image.adjust_saturation(image / 255, saturation_factor) * 255
image = tf.clip_by_value(image, clip_value_min=0.0, clip_value_max=255.0)
return image
image = _augment_only_rgb_channels(image, _adjust_saturation)
return image
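# A minimal sketch chaining the color jitters above on a [0, 255] float
# image. Passing a shared PreprocessorCache instead of the default None would
# make the four random draws repeatable across calls.
def _example_color_jitter(image):
  image = random_adjust_brightness(image, max_delta=0.2)
  image = random_adjust_contrast(image, min_delta=0.8, max_delta=1.25)
  image = random_adjust_hue(image, max_delta=0.02)
  image = random_adjust_saturation(image, min_delta=0.8, max_delta=1.25)
  return image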
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Generate targets (center, scale, offsets,...) for centernet."""
from typing import Dict, List
import tensorflow as tf
from official.vision.beta.ops import sampling_ops
def smallest_positive_root(a, b, c):
"""Returns the smallest positive root of a quadratic equation."""
discriminant = tf.sqrt(b ** 2 - 4 * a * c)
# Note: dividing by 2.0 rather than the textbook 2.0 * a appears to follow
# the original CenterNet `gaussian_radius` implementation, which the caller
# below is documented to mirror.
return (-b + discriminant) / (2.0)
@tf.function
def cartesian_product(*tensors, repeat: int = 1) -> tf.Tensor:
"""Equivalent of itertools.product except for TensorFlow tensors.
Example:
cartesian_product(tf.range(3), tf.range(4))
array([[0, 0],
[0, 1],
[0, 2],
[0, 3],
[1, 0],
[1, 1],
[1, 2],
[1, 3],
[2, 0],
[2, 1],
[2, 2],
[2, 3]], dtype=int32)
Args:
*tensors: a list of 1D tensors to compute the product of
repeat: an `int` number of times to repeat the tensors
Returns:
A 2D tensor of shape [N, n], where n is the number of input tensors
(after repetition) and N is the product of their lengths.
"""
tensors = tensors * repeat
return tf.reshape(tf.transpose(tf.stack(tf.meshgrid(*tensors, indexing='ij')),
[*[i + 1 for i in range(len(tensors))], 0]),
(-1, len(tensors)))
def image_shape_to_grids(height: int, width: int):
"""Computes xy-grids given the shape of the image.
Args:
height: The height of the image.
width: The width of the image.
Returns:
A tuple of two tensors:
y_grid: A float tensor with shape [height, width] representing the
y-coordinate of each pixel grid.
x_grid: A float tensor with shape [height, width] representing the
x-coordinate of each pixel grid.
"""
out_height = tf.cast(height, tf.float32)
out_width = tf.cast(width, tf.float32)
x_range = tf.range(out_width, dtype=tf.float32)
y_range = tf.range(out_height, dtype=tf.float32)
x_grid, y_grid = tf.meshgrid(x_range, y_range, indexing='xy')
return (y_grid, x_grid)
def max_distance_for_overlap(height, width, min_iou):
"""Computes how far apart bbox corners can lie while maintaining the iou.
Given a bounding box size, this function returns a lower bound on how far
apart the corners of another box can lie while still maintaining the given
IoU. The implementation is based on the `gaussian_radius` function in the
Objects as Points github repo: https://github.com/xingyizhou/CenterNet
Args:
height: A 1-D float Tensor representing height of the ground truth boxes.
width: A 1-D float Tensor representing width of the ground truth boxes.
min_iou: A float representing the minimum IoU desired.
Returns:
distance: A 1-D Tensor of distances, of the same length as the input
height and width tensors.
"""
# Given that the detected box is displaced at a distance `d`, the exact
# IoU value will depend on the angle at which each corner is displaced.
# We simplify our computation by assuming that each corner is displaced by
# a distance `d` in both x and y direction. This gives us a lower IoU than
# what is actually realizable and ensures that any box with corners less
# than `d` distance apart will always have an IoU greater than or equal
# to `min_iou`
# The following 3 cases can be worked on geometrically and come down to
# solving a quadratic inequality. In each case, to ensure `min_iou` we use
# the smallest positive root of the equation.
# Case where detected box is offset from ground truth and no box completely
# contains the other.
distance_detection_offset = smallest_positive_root(
a=1, b=-(height + width),
c=width * height * ((1 - min_iou) / (1 + min_iou))
)
# Case where detection is smaller than ground truth and completely contained
# in it.
distance_detection_in_gt = smallest_positive_root(
a=4, b=-2 * (height + width),
c=(1 - min_iou) * width * height
)
# Case where ground truth is smaller than detection and completely contained
# in it.
distance_gt_in_detection = smallest_positive_root(
a=4 * min_iou, b=(2 * min_iou) * (width + height),
c=(min_iou - 1) * width * height
)
return tf.reduce_min([distance_detection_offset,
distance_gt_in_detection,
distance_detection_in_gt], axis=0)
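# A minimal sketch: one distance bound per ground truth box. The bound
# shrinks as min_iou grows, since a tighter overlap requirement leaves less
# room for corner displacement. The box sizes below are illustrative.
def _example_overlap_distance():
  heights = tf.constant([10.0, 50.0])
  widths = tf.constant([10.0, 100.0])
  return max_distance_for_overlap(heights, widths, min_iou=0.7)  # shape [2]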
def compute_std_dev_from_box_size(boxes_height, boxes_width, min_overlap):
"""Computes the standard deviation of the Gaussian kernel from box size.
Args:
boxes_height: A 1D tensor with shape [num_instances] representing the height
of each box.
boxes_width: A 1D tensor with shape [num_instances] representing the width
of each box.
min_overlap: The minimum IOU overlap that boxes need to have to not be
penalized.
Returns:
A 1D tensor with shape [num_instances] representing the computed Gaussian
sigma for each of the box.
"""
# We are dividing by 3 so that points closer than the computed
# distance have a >99% CDF.
sigma = max_distance_for_overlap(boxes_height, boxes_width, min_overlap)
sigma = (2 * tf.math.maximum(tf.math.floor(sigma), 0.0) + 1) / 6.0
return sigma
@tf.function
def assign_center_targets(out_height: int,
out_width: int,
y_center: tf.Tensor,
x_center: tf.Tensor,
boxes_height: tf.Tensor,
boxes_width: tf.Tensor,
channel_onehot: tf.Tensor,
gaussian_iou: float):
"""Computes the object center heatmap target based on ODAPI implementation.
Args:
out_height: int, height of the output feature map; determines the height
of the generated heatmap.
out_width: int, width of the output feature map; determines the width
of the generated heatmap.
y_center: A 1D tensor with shape [num_instances] representing the
y-coordinates of the instances in the output space coordinates.
x_center: A 1D tensor with shape [num_instances] representing the
x-coordinates of the instances in the output space coordinates.
boxes_height: A 1D tensor with shape [num_instances] representing the height
of each box.
boxes_width: A 1D tensor with shape [num_instances] representing the width
of each box.
channel_onehot: A 2D tensor with shape [num_instances, num_channels]
representing the one-hot encoded channel labels for each point.
gaussian_iou: The minimum IOU overlap that boxes need to have to not be
penalized.
Returns:
heatmap: A Tensor of size [output_height, output_width,
num_classes] representing the per class center heatmap. output_height
and output_width are computed by dividing the input height and width by
the stride specified during initialization.
"""
(y_grid, x_grid) = image_shape_to_grids(out_height, out_width)
sigma = compute_std_dev_from_box_size(boxes_height, boxes_width,
gaussian_iou)
num_instances, num_channels = (
sampling_ops.combined_static_and_dynamic_shape(channel_onehot))
x_grid = tf.expand_dims(x_grid, 2)
y_grid = tf.expand_dims(y_grid, 2)
# The raw center coordinates in the output space.
x_diff = x_grid - tf.math.floor(x_center)
y_diff = y_grid - tf.math.floor(y_center)
squared_distance = x_diff ** 2 + y_diff ** 2
gaussian_map = tf.exp(-squared_distance / (2 * sigma * sigma))
reshaped_gaussian_map = tf.expand_dims(gaussian_map, axis=-1)
reshaped_channel_onehot = tf.reshape(channel_onehot,
(1, 1, num_instances, num_channels))
gaussian_per_box_per_class_map = (
reshaped_gaussian_map * reshaped_channel_onehot)
# Take maximum along the "instance" dimension so that all per-instance
# heatmaps of the same class are merged together.
heatmap = tf.reduce_max(gaussian_per_box_per_class_map, axis=2)
# Maximum of an empty tensor is -inf, the following is to avoid that.
heatmap = tf.maximum(heatmap, 0)
return tf.stop_gradient(heatmap)
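# A minimal sketch of the heatmap assignment: two instances of different
# classes produce a [out_height, out_width, 2] map with unit peaks at the
# (floored) center locations. All values below are illustrative.
def _example_center_heatmap():
  return assign_center_targets(
      out_height=128,
      out_width=128,
      y_center=tf.constant([32.0, 64.0]),
      x_center=tf.constant([32.0, 96.0]),
      boxes_height=tf.constant([10.0, 20.0]),
      boxes_width=tf.constant([10.0, 20.0]),
      channel_onehot=tf.constant([[1., 0.], [0., 1.]]),
      gaussian_iou=0.7)  # [128, 128, 2]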
def assign_centernet_targets(labels: Dict[str, tf.Tensor],
output_size: List[int],
input_size: List[int],
num_classes: int = 90,
max_num_instances: int = 128,
gaussian_iou: float = 0.7,
class_offset: int = 0,
dtype='float32'):
"""Generates the ground truth labels for centernet.
Ground truth labels are generated by splatting gaussians on heatmaps for
corners and centers. Regressed features (offsets and sizes) are also
generated.
Args:
labels: A dictionary of COCO ground truth labels with at minimum the
following fields:
"bbox" A `Tensor` of shape [max_num_instances, 4], where the
last dimension corresponds to the top left x, top left y,
bottom right x, and bottom left y coordinates of the bounding box
"classes" A `Tensor` of shape [max_num_instances] that contains
the class of each box, given in the same order as the boxes
"num_detections" A `Tensor` or int that gives the number of objects
output_size: A `list` of length 2 containing the desired output height
and width of the heatmaps
input_size: A `list` of length 2 containing the expected input height and
width of the image
num_classes: A `Tensor` or `int` for the number of classes.
max_num_instances: An `int` for maximum number of instances in an image.
gaussian_iou: A `float` number for the minimum desired IOU used when
determining the gaussian radius of center locations in the heatmap.
class_offset: An `int` value subtracted from the ground truth classes
dtype: `str`, data type. One of {`bfloat16`, `float32`, `float16`}.
Returns:
Dictionary of labels with the following fields:
'ct_heatmaps': Tensor of shape [output_h, output_w, num_classes],
heatmap with splatted gaussians centered at the positions and channels
corresponding to the center location and class of the object
'ct_offset': `Tensor` of shape [max_num_instances, 2], where the first
num_boxes entries contain the y-offset and x-offset of the center of
an object. All other entries are 0
'size': `Tensor` of shape [max_num_instances, 2], where the first
num_boxes entries contain the height and width of an object. All
other entries are 0
'box_mask': `Tensor` of shape [max_num_instances], where the first
num_boxes entries are 1. All other entries are 0
'box_indices': `Tensor` of shape [max_num_instances, 2], where the first
num_boxes entries contain the y-center and x-center of a valid box.
These are used to extract the regressed box features from the
prediction when computing the loss
Raises:
Exception: if datatype is not supported.
"""
if dtype == 'float16':
dtype = tf.float16
elif dtype == 'bfloat16':
dtype = tf.bfloat16
elif dtype == 'float32':
dtype = tf.float32
else:
raise Exception(
'Unsupported datatype used in ground truth builder; only '
'float16, bfloat16, and float32 are supported.')
# Get relevant bounding box and class information from labels
# only keep the first num_objects boxes and classes
num_objects = labels['groundtruths']['num_detections']
# shape of labels['boxes'] is [max_num_instances, 4]
# [ymin, xmin, ymax, xmax]
boxes = tf.cast(labels['boxes'], dtype)
# shape of labels['classes'] is [max_num_instances, ]
classes = tf.cast(labels['classes'] - class_offset, dtype)
# Compute scaling factors for center/corner positions on heatmap
# input_size = tf.cast(input_size, dtype)
# output_size = tf.cast(output_size, dtype)
input_h, input_w = input_size[0], input_size[1]
output_h, output_w = output_size[0], output_size[1]
width_ratio = output_w / input_w
height_ratio = output_h / input_h
# Original box coordinates
# [max_num_instances, ]
ytl, ybr = boxes[..., 0], boxes[..., 2]
xtl, xbr = boxes[..., 1], boxes[..., 3]
yct = (ytl + ybr) / 2
xct = (xtl + xbr) / 2
# Scaled box coordinates (could be floating point)
# [max_num_instances, ]
scale_xct = xct * width_ratio
scale_yct = yct * height_ratio
# Floor the scaled box coordinates to be placed on heatmaps
# [max_num_instances, ]
scale_xct_floor = tf.math.floor(scale_xct)
scale_yct_floor = tf.math.floor(scale_yct)
# Offset computations to make up for discretization error
# used for offset maps
# [max_num_instances, 2]
ct_offset_values = tf.stack([scale_yct - scale_yct_floor,
scale_xct - scale_xct_floor], axis=-1)
# Get the scaled box dimensions for computing the gaussian radius
# [max_num_instances, ]
box_widths = boxes[..., 3] - boxes[..., 1]
box_heights = boxes[..., 2] - boxes[..., 0]
box_widths = box_widths * width_ratio
box_heights = box_heights * height_ratio
# Used for size map
# [max_num_instances, 2]
box_heights_widths = tf.stack([box_heights, box_widths], axis=-1)
# Center/corner heatmaps
# [output_h, output_w, num_classes]
ct_heatmap = tf.zeros((output_h, output_w, num_classes), dtype)
# Maps for offset and size features for each instance of a box
# [max_num_instances, 2]
ct_offset = tf.zeros((max_num_instances, 2), dtype)
# [max_num_instances, 2]
size = tf.zeros((max_num_instances, 2), dtype)
# Mask for valid box instances and their center indices in the heatmap
# [max_num_instances, ]
box_mask = tf.zeros((max_num_instances,), tf.int32)
# [max_num_instances, 2]
box_indices = tf.zeros((max_num_instances, 2), tf.int32)
if num_objects > 0:
# Splat gaussians around the centers of the valid objects.
ct_heatmap = assign_center_targets(
out_height=output_h,
out_width=output_w,
y_center=scale_yct_floor[:num_objects],
x_center=scale_xct_floor[:num_objects],
boxes_height=box_heights[:num_objects],
boxes_width=box_widths[:num_objects],
channel_onehot=tf.one_hot(tf.cast(classes[:num_objects], tf.int32),
num_classes, off_value=0.),
gaussian_iou=gaussian_iou)
# Indices used to update offsets and sizes for valid box instances
update_indices = cartesian_product(
tf.range(max_num_instances), tf.range(2))
# [max_num_instances, 2, 2]
update_indices = tf.reshape(update_indices, shape=[max_num_instances, 2, 2])
# Write the offsets of each box instance
ct_offset = tf.tensor_scatter_nd_update(
ct_offset, update_indices, ct_offset_values)
# Write the size of each bounding box
size = tf.tensor_scatter_nd_update(
size, update_indices, box_heights_widths)
# Initially the mask is zeros, so now we unmask each valid box instance
box_mask = tf.where(tf.range(max_num_instances) < num_objects, 1, 0)
# Write the y and x coordinate of each box center in the heatmap
box_index_values = tf.cast(
tf.stack([scale_yct_floor, scale_xct_floor], axis=-1),
dtype=tf.int32)
box_indices = tf.tensor_scatter_nd_update(
box_indices, update_indices, box_index_values)
ct_labels = {
# [output_h, output_w, num_classes]
'ct_heatmaps': ct_heatmap,
# [max_num_instances, 2]
'ct_offset': ct_offset,
# [max_num_instances, 2]
'size': size,
# [max_num_instances, ]
'box_mask': box_mask,
# [max_num_instances, 2]
'box_indices': box_indices
}
return ct_labels
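# A minimal sketch of target generation. Boxes and classes must be padded to
# max_num_instances rows (as in the tests below) because the scatter updates
# write one row per padded instance.
def _example_assign_targets():
  from official.vision.beta.ops import preprocess_ops
  boxes = tf.constant([[10., 300., 15., 370.],
                       [100., 300., 150., 370.]])
  classes = tf.constant([1., 2.])
  boxes = preprocess_ops.clip_or_pad_to_fixed_size(boxes, 128, 0)
  classes = preprocess_ops.clip_or_pad_to_fixed_size(classes, 128, 0)
  labels = {'boxes': boxes, 'classes': classes,
            'groundtruths': {'num_detections': 2}}
  return assign_centernet_targets(
      labels, output_size=[128, 128], input_size=[512, 512])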
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for targets generations of centernet."""
from absl.testing import parameterized
import tensorflow as tf
from official.vision.beta.ops import preprocess_ops
from official.vision.beta.projects.centernet.ops import target_assigner
class TargetAssignerTest(tf.test.TestCase, parameterized.TestCase):
def check_labels_correct(self,
boxes,
classes,
output_size,
input_size):
max_num_instances = 128
num_detections = len(boxes)
boxes = tf.constant(boxes, dtype=tf.float32)
classes = tf.constant(classes, dtype=tf.float32)
boxes = preprocess_ops.clip_or_pad_to_fixed_size(
boxes, max_num_instances, 0)
classes = preprocess_ops.clip_or_pad_to_fixed_size(
classes, max_num_instances, 0)
# pylint: disable=g-long-lambda
labels = target_assigner.assign_centernet_targets(
labels={
'boxes': boxes,
'classes': classes,
'groundtruths': {
'num_detections': num_detections,
}
},
output_size=output_size,
input_size=input_size)
ct_heatmaps = labels['ct_heatmaps']
ct_offset = labels['ct_offset']
size = labels['size']
box_mask = labels['box_mask']
box_indices = labels['box_indices']
boxes = tf.cast(boxes, tf.float32)
classes = tf.cast(classes, tf.float32)
height_ratio = output_size[0] / input_size[0]
width_ratio = output_size[1] / input_size[1]
# Shape checks
self.assertEqual(ct_heatmaps.shape, (output_size[0], output_size[1], 90))
self.assertEqual(ct_offset.shape, (max_num_instances, 2))
self.assertEqual(size.shape, (max_num_instances, 2))
self.assertEqual(box_mask.shape, (max_num_instances,))
self.assertEqual(box_indices.shape, (max_num_instances, 2))
self.assertAllInRange(ct_heatmaps, 0, 1)
for i in range(num_detections):
# Check sizes
self.assertAllEqual(size[i],
[(boxes[i][2] - boxes[i][0]) * height_ratio,
(boxes[i][3] - boxes[i][1]) * width_ratio,
])
# Check box indices
y = tf.math.floor((boxes[i][0] + boxes[i][2]) / 2 * height_ratio)
x = tf.math.floor((boxes[i][1] + boxes[i][3]) / 2 * width_ratio)
self.assertAllEqual(box_indices[i], [y, x])
# check offsets
true_y = (boxes[i][0] + boxes[i][2]) / 2 * height_ratio
true_x = (boxes[i][1] + boxes[i][3]) / 2 * width_ratio
self.assertAllEqual(ct_offset[i], [true_y - y, true_x - x])
for i in range(num_detections, max_num_instances):
# Make sure rest are zero
self.assertAllEqual(size[i], [0, 0])
self.assertAllEqual(box_indices[i], [0, 0])
self.assertAllEqual(ct_offset[i], [0, 0])
# Check mask indices
self.assertAllEqual(tf.cast(box_mask[num_detections:], tf.int32),
tf.repeat(0, repeats=max_num_instances - num_detections))
self.assertAllEqual(tf.cast(box_mask[:num_detections], tf.int32),
tf.repeat(1, repeats=num_detections))
def test_generate_targets_no_scale(self):
boxes = [
(10, 300, 15, 370),
(100, 300, 150, 370),
(15, 100, 200, 170),
]
classes = (1, 2, 3)
sizes = [512, 512]
self.check_labels_correct(boxes=boxes,
classes=classes,
output_size=sizes,
input_size=sizes)
def test_generate_targets_stride_4(self):
boxes = [
(10, 300, 15, 370),
(100, 300, 150, 370),
(15, 100, 200, 170),
]
classes = (1, 2, 3)
output_size = [128, 128]
input_size = [512, 512]
self.check_labels_correct(boxes=boxes,
classes=classes,
output_size=output_size,
input_size=input_size)
def test_generate_targets_stride_8(self):
boxes = [
(10, 300, 15, 370),
(100, 300, 150, 370),
(15, 100, 200, 170),
]
classes = (1, 2, 3)
output_size = [128, 128]
input_size = [1024, 1024]
self.check_labels_correct(boxes=boxes,
classes=classes,
output_size=output_size,
input_size=input_size)
def test_batch_generate_targets(self):
input_size = [512, 512]
output_size = [128, 128]
max_num_instances = 128
boxes = tf.constant([
(10, 300, 15, 370), # center (y, x) = (12, 335)
(100, 300, 150, 370), # center (y, x) = (125, 335)
(15, 100, 200, 170), # center (y, x) = (107, 135)
], dtype=tf.float32)
classes = tf.constant((1, 1, 1), dtype=tf.float32)
boxes = preprocess_ops.clip_or_pad_to_fixed_size(
boxes, max_num_instances, 0)
classes = preprocess_ops.clip_or_pad_to_fixed_size(
classes, max_num_instances, 0)
boxes = tf.stack([boxes, boxes], axis=0)
classes = tf.stack([classes, classes], axis=0)
# pylint: disable=g-long-lambda
labels = tf.map_fn(
fn=lambda x: target_assigner.assign_centernet_targets(
labels=x,
output_size=output_size,
input_size=input_size),
elems={
'boxes': boxes,
'classes': classes,
'groundtruths': {
'num_detections': tf.constant([3, 3]),
}
},
dtype={
'ct_heatmaps': tf.float32,
'ct_offset': tf.float32,
'size': tf.float32,
'box_mask': tf.int32,
'box_indices': tf.int32
}
)
ct_heatmaps = labels['ct_heatmaps']
ct_offset = labels['ct_offset']
size = labels['size']
box_mask = labels['box_mask']
box_indices = labels['box_indices']
self.assertEqual(ct_heatmaps.shape, (2, output_size[0], output_size[1], 90))
self.assertEqual(ct_offset.shape, (2, max_num_instances, 2))
self.assertEqual(size.shape, (2, max_num_instances, 2))
self.assertEqual(box_mask.shape, (2, max_num_instances))
self.assertEqual(box_indices.shape, (2, max_num_instances, 2))
if __name__ == '__main__':
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TensorFlow Model Garden Vision Centernet trainer."""
from absl import app
from absl import flags
import gin
from official.common import distribute_utils
from official.common import flags as tfm_flags
from official.core import task_factory
from official.core import train_lib
from official.core import train_utils
from official.modeling import performance
from official.vision.beta.projects.centernet.common import registry_imports # pylint: disable=unused-import
FLAGS = flags.FLAGS
def main(_):
gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
params = train_utils.parse_configuration(FLAGS)
model_dir = FLAGS.model_dir
if 'train' in FLAGS.mode:
# Pure eval modes do not output yaml files. Otherwise continuous eval job
# may race against the train job for writing the same file.
train_utils.serialize_config(params, model_dir)
# Sets mixed_precision policy. Using 'mixed_float16' or 'mixed_bfloat16'
# can have significant impact on model speeds by utilizing float16 in case of
# GPUs, and bfloat16 in the case of TPUs. loss_scale takes effect only when
# dtype is float16
if params.runtime.mixed_precision_dtype:
performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype)
distribution_strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=params.runtime.distribution_strategy,
all_reduce_alg=params.runtime.all_reduce_alg,
num_gpus=params.runtime.num_gpus,
tpu_address=params.runtime.tpu)
with distribution_strategy.scope():
task = task_factory.get_task(params.task, logging_dir=model_dir)
train_lib.run_experiment(
distribution_strategy=distribution_strategy,
task=task,
mode=FLAGS.mode,
params=params,
model_dir=model_dir)
train_utils.save_gin_config(FLAGS.mode, model_dir)
if __name__ == '__main__':
tfm_flags.define_flags()
app.run(main)
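# A typical invocation, assuming an experiment registered via
# registry_imports (the exact experiment name depends on the project's
# registered configs):
#
#   python3 train.py \
#     --experiment=centernet_hourglass_coco \
#     --mode=train_and_eval \
#     --model_dir=/tmp/centernet \
#     --config_file=path/to/overrides.yaml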
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Configurations for loading checkpoints."""
import dataclasses
from typing import Dict, Optional
import numpy as np
from official.vision.beta.projects.centernet.utils.checkpoints import config_classes
Conv2DBNCFG = config_classes.Conv2DBNCFG
HeadConvCFG = config_classes.HeadConvCFG
ResidualBlockCFG = config_classes.ResidualBlockCFG
HourglassCFG = config_classes.HourglassCFG
@dataclasses.dataclass
class BackboneConfigData:
"""Backbone Config."""
weights_dict: Optional[Dict[str, np.ndarray]] = dataclasses.field(
repr=False, default=None)
def get_cfg_list(self, name):
"""Get list of block configs for the module."""
if name == 'hourglass104_512':
return [
# Downsampling Layers
Conv2DBNCFG(
weights_dict=self.weights_dict['downsample_input']['conv_block']),
ResidualBlockCFG(
weights_dict=self.weights_dict['downsample_input'][
'residual_block']),
# Hourglass
HourglassCFG(
weights_dict=self.weights_dict['hourglass_network']['0']),
Conv2DBNCFG(
weights_dict=self.weights_dict['output_conv']['0']),
# Intermediate
Conv2DBNCFG(
weights_dict=self.weights_dict['intermediate_conv1']['0']),
Conv2DBNCFG(
weights_dict=self.weights_dict['intermediate_conv2']['0']),
ResidualBlockCFG(
weights_dict=self.weights_dict['intermediate_residual']['0']),
# Hourglass
HourglassCFG(
weights_dict=self.weights_dict['hourglass_network']['1']),
Conv2DBNCFG(
weights_dict=self.weights_dict['output_conv']['1']),
]
elif name == 'extremenet':
return [
# Downsampling Layers
Conv2DBNCFG(
weights_dict=self.weights_dict['downsample_input']['conv_block']),
ResidualBlockCFG(
weights_dict=self.weights_dict['downsample_input'][
'residual_block']),
# Hourglass
HourglassCFG(
weights_dict=self.weights_dict['hourglass_network']['0']),
Conv2DBNCFG(
weights_dict=self.weights_dict['output_conv']['0']),
# Intermediate
Conv2DBNCFG(
weights_dict=self.weights_dict['intermediate_conv1']['0']),
Conv2DBNCFG(
weights_dict=self.weights_dict['intermediate_conv2']['0']),
ResidualBlockCFG(
weights_dict=self.weights_dict['intermediate_residual']['0']),
# Hourglass
HourglassCFG(
weights_dict=self.weights_dict['hourglass_network']['1']),
Conv2DBNCFG(
weights_dict=self.weights_dict['output_conv']['1']),
]
@dataclasses.dataclass
class HeadConfigData:
"""Head Config."""
weights_dict: Optional[Dict[str, np.ndarray]] = dataclasses.field(
repr=False, default=None)
def get_cfg_list(self, name):
"""Get list of block configs for the head."""
if name == 'detection_2d':
return [
HeadConvCFG(weights_dict=self.weights_dict['object_center']['0']),
HeadConvCFG(weights_dict=self.weights_dict['object_center']['1']),
HeadConvCFG(weights_dict=self.weights_dict['box.Soffset']['0']),
HeadConvCFG(weights_dict=self.weights_dict['box.Soffset']['1']),
HeadConvCFG(weights_dict=self.weights_dict['box.Sscale']['0']),
HeadConvCFG(weights_dict=self.weights_dict['box.Sscale']['1'])
]