Commit e4be7e00 authored by Yeqing Li, committed by A. Unique TensorFlower

Removes unneeded content of the beta folder.

PiperOrigin-RevId: 437276665
parent f47405b5
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Class to subsample minibatches by balancing positives and negatives.
Subsamples minibatches based on a pre-specified positive fraction in range
[0,1]. The class presumes there are many more negatives than positive examples:
if the desired batch_size cannot be achieved with the pre-specified positive
fraction, it fills the rest with negative examples. If this is not sufficient
for obtaining the desired batch_size, it returns fewer examples.
The main function to call is Subsample(self, indicator, labels). For convenience
one can also call SubsampleWeights(self, weights, labels) which is defined in
the minibatch_sampler base class.
When is_static is True, it implements a method that guarantees static shapes.
It also ensures the length of output of the subsample is always batch_size, even
when number of examples set to True in indicator is less than batch_size.
This is originally implemented in TensorFlow Object Detection API.
"""
# Import libraries
import tensorflow as tf
def combined_static_and_dynamic_shape(tensor):
"""Returns a list containing static and dynamic values for the dimensions.
Returns a list of static and dynamic values for shape dimensions. This is
useful to preserve static shapes when available in reshape operation.
Args:
tensor: A tensor of any type.
Returns:
    A list of size tensor.shape.ndims whose entries are Python integers (for
    static dimensions) or scalar tensors (for dynamic dimensions).
"""
static_tensor_shape = tensor.shape.as_list()
dynamic_tensor_shape = tf.shape(input=tensor)
combined_shape = []
for index, dim in enumerate(static_tensor_shape):
if dim is not None:
combined_shape.append(dim)
else:
combined_shape.append(dynamic_tensor_shape[index])
return combined_shape
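# A minimal usage sketch, assuming a Keras input with an unknown batch
# dimension: static dims come back as Python ints and the batch dim as a
# scalar tensor, so the reshape below preserves the static dims.
#
#   images = tf.keras.Input(shape=(224, 224, 3))  # shape: [None, 224, 224, 3]
#   shape = combined_static_and_dynamic_shape(images)  # [tensor, 224, 224, 3]
#   flat = tf.reshape(images, [shape[0], shape[1] * shape[2] * shape[3]])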
def indices_to_dense_vector(indices,
size,
indices_value=1.,
default_value=0,
dtype=tf.float32):
"""Creates dense vector with indices set to specific value and rest to zeros.
This function exists because it is unclear if it is safe to use
tf.sparse_to_dense(indices, [size], 1, validate_indices=False)
with indices which are not ordered.
This function accepts a dynamic size (e.g. tf.shape(tensor)[0])
  Args:
    indices: 1d Tensor with integer indices which are to be set to
      indices_value.
    size: scalar with size (integer) of output Tensor.
    indices_value: values of elements specified by indices in the output
      vector.
    default_value: values of other elements in the output vector.
    dtype: data type.
  Returns:
    dense 1D Tensor of shape [size] with indices set to indices_value and the
    rest set to default_value.
"""
size = tf.cast(size, dtype=tf.int32)
zeros = tf.ones([size], dtype=dtype) * default_value
values = tf.ones_like(indices, dtype=dtype) * indices_value
return tf.dynamic_stitch(
[tf.range(size), tf.cast(indices, dtype=tf.int32)], [zeros, values])
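# A small worked example, assuming eager execution: scatter indices_value=1.
# at positions 2 and 0 of a length-5 vector, with the default elsewhere.
#
#   indices_to_dense_vector(tf.constant([2, 0]), size=5)
#   # -> [1., 0., 1., 0., 0.]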
def matmul_gather_on_zeroth_axis(params, indices, scope=None):
"""Matrix multiplication based implementation of tf.gather on zeroth axis.
TODO(rathodv, jonathanhuang): enable sparse matmul option.
Args:
params: A float32 Tensor. The tensor from which to gather values.
Must be at least rank 1.
indices: A Tensor. Must be one of the following types: int32, int64.
Must be in range [0, params.shape[0])
scope: A name for the operation (optional).
  Returns:
    A Tensor. Has the same type as params. Values from params gathered at the
    positions given by indices, with shape indices.shape + params.shape[1:].
"""
scope = scope or 'MatMulGather'
with tf.name_scope(scope):
params_shape = combined_static_and_dynamic_shape(params)
indices_shape = combined_static_and_dynamic_shape(indices)
params2d = tf.reshape(params, [params_shape[0], -1])
indicator_matrix = tf.one_hot(indices, params_shape[0])
gathered_result_flattened = tf.matmul(indicator_matrix, params2d)
return tf.reshape(gathered_result_flattened,
tf.stack(indices_shape + params_shape[1:]))
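# Sanity-check sketch: the one-hot matmul gather agrees with tf.gather on the
# zeroth axis (it is used here because it preserves static shapes).
#
#   params = tf.random.uniform([4, 3])
#   indices = tf.constant([2, 0, 2])
#   tf.debugging.assert_near(
#       matmul_gather_on_zeroth_axis(params, indices),
#       tf.gather(params, indices))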
class BalancedPositiveNegativeSampler:
"""Subsamples minibatches to a desired balance of positives and negatives."""
def __init__(self, positive_fraction=0.5, is_static=False):
"""Constructs a minibatch sampler.
Args:
positive_fraction: desired fraction of positive examples (scalar in [0,1])
in the batch.
is_static: If True, uses an implementation with static shape guarantees.
Raises:
ValueError: if positive_fraction < 0, or positive_fraction > 1
"""
if positive_fraction < 0 or positive_fraction > 1:
raise ValueError('positive_fraction should be in range [0,1]. '
'Received: %s.' % positive_fraction)
self._positive_fraction = positive_fraction
self._is_static = is_static
@staticmethod
def subsample_indicator(indicator, num_samples):
"""Subsample indicator vector.
Given a boolean indicator vector with M elements set to `True`, the function
assigns all but `num_samples` of these previously `True` elements to
`False`. If `num_samples` is greater than M, the original indicator vector
is returned.
Args:
indicator: a 1-dimensional boolean tensor indicating which elements
are allowed to be sampled and which are not.
num_samples: int32 scalar tensor
Returns:
a boolean tensor with the same shape as input (indicator) tensor
"""
indices = tf.where(indicator)
indices = tf.random.shuffle(indices)
indices = tf.reshape(indices, [-1])
num_samples = tf.minimum(tf.size(input=indices), num_samples)
selected_indices = tf.slice(indices, [0], tf.reshape(num_samples, [1]))
selected_indicator = indices_to_dense_vector(
selected_indices,
tf.shape(input=indicator)[0])
return tf.equal(selected_indicator, 1)
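  # Example, assuming eager execution: keep at most two of the allowed
  # entries, chosen uniformly at random.
  #
  #   indicator = tf.constant([True, True, False, True])
  #   sampled = BalancedPositiveNegativeSampler.subsample_indicator(
  #       indicator, tf.constant(2))
  #   # `sampled` has exactly two True entries, a random subset of the three.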
def _get_num_pos_neg_samples(self, sorted_indices_tensor, sample_size):
"""Counts the number of positives and negatives numbers to be sampled.
Args:
      sorted_indices_tensor: A sorted int32 tensor of shape [N] which contains
        the signed indices of the examples, where the sign is based on the
        label value. Examples that cannot be sampled are set to 0. At most
        sample_size * positive_fraction positive examples are sampled; the
        remainder of the sample is drawn from the negative examples.
sample_size: Size of subsamples.
Returns:
A tuple containing the number of positive and negative labels in the
subsample.
"""
input_length = tf.shape(input=sorted_indices_tensor)[0]
valid_positive_index = tf.greater(sorted_indices_tensor,
tf.zeros(input_length, tf.int32))
num_sampled_pos = tf.reduce_sum(
input_tensor=tf.cast(valid_positive_index, tf.int32))
max_num_positive_samples = tf.constant(
int(sample_size * self._positive_fraction), tf.int32)
num_positive_samples = tf.minimum(max_num_positive_samples, num_sampled_pos)
num_negative_samples = tf.constant(sample_size,
tf.int32) - num_positive_samples
return num_positive_samples, num_negative_samples
def _get_values_from_start_and_end(self, input_tensor, num_start_samples,
num_end_samples, total_num_samples):
"""slices num_start_samples and last num_end_samples from input_tensor.
Args:
input_tensor: An int32 tensor of shape [N] to be sliced.
num_start_samples: Number of examples to be sliced from the beginning
of the input tensor.
num_end_samples: Number of examples to be sliced from the end of the
input tensor.
      total_num_samples: Sum of num_start_samples and num_end_samples. This
        should be a scalar.
Returns:
A tensor containing the first num_start_samples and last num_end_samples
from input_tensor.
"""
input_length = tf.shape(input=input_tensor)[0]
start_positions = tf.less(tf.range(input_length), num_start_samples)
end_positions = tf.greater_equal(
tf.range(input_length), input_length - num_end_samples)
selected_positions = tf.logical_or(start_positions, end_positions)
selected_positions = tf.cast(selected_positions, tf.float32)
indexed_positions = tf.multiply(tf.cumsum(selected_positions),
selected_positions)
one_hot_selector = tf.one_hot(tf.cast(indexed_positions, tf.int32) - 1,
total_num_samples,
dtype=tf.float32)
return tf.cast(tf.tensordot(tf.cast(input_tensor, tf.float32),
one_hot_selector, axes=[0, 0]), tf.int32)
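  # Worked example: with num_start_samples=2, num_end_samples=1 and
  # total_num_samples=3 on a length-5 tensor, the cumsum-based one-hot
  # selector picks the first two and the last element.
  #
  #   t = tf.constant([10, 20, 30, 40, 50], tf.int32)
  #   self._get_values_from_start_and_end(t, 2, 1, 3)  # -> [10, 20, 50]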
def _static_subsample(self, indicator, batch_size, labels):
"""Returns subsampled minibatch.
Args:
      indicator: boolean tensor of shape [N] whose True entries can be sampled.
        N should be a compile-time constant.
      batch_size: desired batch size. This scalar cannot be None.
      labels: boolean tensor of shape [N] denoting positive (=True) and
        negative (=False) examples. N should be a compile-time constant.
Returns:
      sampled_idx_indicator: boolean tensor of shape [N], True for entries
        which are sampled. It ensures that the length of the subsample output
        is always batch_size, even when the number of examples set to True in
        indicator is less than batch_size.
Raises:
ValueError: if labels and indicator are not 1D boolean tensors.
"""
# Check if indicator and labels have a static size.
    if not indicator.shape.is_fully_defined():
      raise ValueError('indicator must be static in shape when is_static is '
                       'True.')
    if not labels.shape.is_fully_defined():
      raise ValueError('labels must be static in shape when is_static is '
                       'True.')
    if not isinstance(batch_size, int):
      raise ValueError('batch_size has to be an integer when is_static is '
                       'True.')
input_length = tf.shape(input=indicator)[0]
    # Ensure that the number of entries set to True in indicator is at least
    # batch_size by additionally flipping some False entries to True.
num_true_sampled = tf.reduce_sum(
input_tensor=tf.cast(indicator, tf.float32))
additional_false_sample = tf.less_equal(
tf.cumsum(tf.cast(tf.logical_not(indicator), tf.float32)),
batch_size - num_true_sampled)
indicator = tf.logical_or(indicator, additional_false_sample)
# Shuffle indicator and label. Need to store the permutation to restore the
# order post sampling.
permutation = tf.random.shuffle(tf.range(input_length))
indicator = matmul_gather_on_zeroth_axis(
tf.cast(indicator, tf.float32), permutation)
labels = matmul_gather_on_zeroth_axis(
tf.cast(labels, tf.float32), permutation)
# index (starting from 1) when indicator is True, 0 when False
indicator_idx = tf.where(
tf.cast(indicator, tf.bool), tf.range(1, input_length + 1),
tf.zeros(input_length, tf.int32))
    # Use -1 for negative and +1 for positive labels.
signed_label = tf.where(
tf.cast(labels, tf.bool), tf.ones(input_length, tf.int32),
tf.scalar_mul(-1, tf.ones(input_length, tf.int32)))
# negative of index for negative label, positive index for positive label,
# 0 when indicator is False.
signed_indicator_idx = tf.multiply(indicator_idx, signed_label)
sorted_signed_indicator_idx = tf.nn.top_k(
signed_indicator_idx, input_length, sorted=True).values
[num_positive_samples,
num_negative_samples] = self._get_num_pos_neg_samples(
sorted_signed_indicator_idx, batch_size)
sampled_idx = self._get_values_from_start_and_end(
sorted_signed_indicator_idx, num_positive_samples,
num_negative_samples, batch_size)
# Shift the indices to start from 0 and remove any samples that are set as
# False.
sampled_idx = tf.abs(sampled_idx) - tf.ones(batch_size, tf.int32)
sampled_idx = tf.multiply(
tf.cast(tf.greater_equal(sampled_idx, tf.constant(0)), tf.int32),
sampled_idx)
sampled_idx_indicator = tf.cast(
tf.reduce_sum(
input_tensor=tf.one_hot(sampled_idx, depth=input_length), axis=0),
tf.bool)
# project back the order based on stored permutations
reprojections = tf.one_hot(permutation, depth=input_length,
dtype=tf.float32)
return tf.cast(tf.tensordot(
tf.cast(sampled_idx_indicator, tf.float32),
reprojections, axes=[0, 0]), tf.bool)
def subsample(self, indicator, batch_size, labels, scope=None):
"""Returns subsampled minibatch.
Args:
indicator: boolean tensor of shape [N] whose True entries can be sampled.
      batch_size: desired batch size. If None, keeps all positive samples and
        randomly selects negative samples so that the positive sample fraction
        matches self._positive_fraction. It cannot be None if is_static is
        True.
      labels: boolean tensor of shape [N] denoting positive (=True) and
        negative (=False) examples.
scope: name scope.
Returns:
sampled_idx_indicator: boolean tensor of shape [N], True for entries which
are sampled.
Raises:
ValueError: if labels and indicator are not 1D boolean tensors.
"""
if len(indicator.get_shape().as_list()) != 1:
raise ValueError('indicator must be 1 dimensional, got a tensor of '
'shape %s' % indicator.get_shape())
if len(labels.get_shape().as_list()) != 1:
raise ValueError('labels must be 1 dimensional, got a tensor of '
'shape %s' % labels.get_shape())
if labels.dtype != tf.bool:
raise ValueError('labels should be of type bool. Received: %s' %
labels.dtype)
if indicator.dtype != tf.bool:
raise ValueError('indicator should be of type bool. Received: %s' %
indicator.dtype)
scope = scope or 'BalancedPositiveNegativeSampler'
with tf.name_scope(scope):
if self._is_static:
return self._static_subsample(indicator, batch_size, labels)
else:
# Only sample from indicated samples
negative_idx = tf.logical_not(labels)
positive_idx = tf.logical_and(labels, indicator)
negative_idx = tf.logical_and(negative_idx, indicator)
# Sample positive and negative samples separately
if batch_size is None:
max_num_pos = tf.reduce_sum(
input_tensor=tf.cast(positive_idx, dtype=tf.int32))
else:
max_num_pos = int(self._positive_fraction * batch_size)
sampled_pos_idx = self.subsample_indicator(positive_idx, max_num_pos)
num_sampled_pos = tf.reduce_sum(
input_tensor=tf.cast(sampled_pos_idx, tf.int32))
if batch_size is None:
negative_positive_ratio = (
1 - self._positive_fraction) / self._positive_fraction
max_num_neg = tf.cast(
negative_positive_ratio *
tf.cast(num_sampled_pos, dtype=tf.float32),
dtype=tf.int32)
else:
max_num_neg = batch_size - num_sampled_pos
sampled_neg_idx = self.subsample_indicator(negative_idx, max_num_neg)
return tf.logical_or(sampled_pos_idx, sampled_neg_idx)
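# A minimal usage sketch of the dynamic path (is_static=False), assuming eager
# execution: sample a minibatch of 4 from 8 candidates at positive_fraction
# 0.5.
#
#   sampler = BalancedPositiveNegativeSampler(positive_fraction=0.5)
#   indicator = tf.constant([True] * 8)
#   labels = tf.constant([True, True] + [False] * 6)
#   sampled = sampler.subsample(indicator, 4, labels)
#   # `sampled` marks 2 positive and 2 negative entries as True.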
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Spatial transform ops."""
import tensorflow as tf
_EPSILON = 1e-8
def _feature_bilinear_interpolation(features, kernel_y, kernel_x):
"""Feature bilinear interpolation.
The RoIAlign feature f can be computed by bilinear interpolation
of four neighboring feature points f0, f1, f2, and f3.
f(y, x) = [hy, ly] * [[f00, f01], * [hx, lx]^T
[f10, f11]]
f(y, x) = (hy*hx)f00 + (hy*lx)f01 + (ly*hx)f10 + (lx*ly)f11
f(y, x) = w00*f00 + w01*f01 + w10*f10 + w11*f11
kernel_y = [hy, ly]
kernel_x = [hx, lx]
Args:
features: The features are in shape of [batch_size, num_boxes, output_size *
2, output_size * 2, num_filters].
kernel_y: Tensor of size [batch_size, boxes, output_size, 2, 1].
kernel_x: Tensor of size [batch_size, boxes, output_size, 2, 1].
Returns:
A 5-D tensor representing feature crop of shape
[batch_size, num_boxes, output_size, output_size, num_filters].
"""
features_shape = tf.shape(features)
batch_size, num_boxes, output_size, num_filters = (
features_shape[0], features_shape[1], features_shape[2],
features_shape[4])
output_size = output_size // 2
kernel_y = tf.reshape(kernel_y, [batch_size, num_boxes, output_size * 2, 1])
kernel_x = tf.reshape(kernel_x, [batch_size, num_boxes, 1, output_size * 2])
# Use implicit broadcast to generate the interpolation kernel. The
# multiplier `4` is for avg pooling.
interpolation_kernel = kernel_y * kernel_x * 4
# Interpolate the gathered features with computed interpolation kernels.
features *= tf.cast(
tf.expand_dims(interpolation_kernel, axis=-1), dtype=features.dtype)
features = tf.reshape(
features,
[batch_size * num_boxes, output_size * 2, output_size * 2, num_filters])
features = tf.nn.avg_pool(features, [1, 2, 2, 1], [1, 2, 2, 1], 'VALID')
features = tf.reshape(
features, [batch_size, num_boxes, output_size, output_size, num_filters])
return features
def _compute_grid_positions(boxes, boundaries, output_size, sample_offset):
"""Computes the grid position w.r.t. the corresponding feature map.
Args:
boxes: a 3-D tensor of shape [batch_size, num_boxes, 4] encoding the
information of each box w.r.t. the corresponding feature map.
boxes[:, :, 0:2] are the grid position in (y, x) (float) of the top-left
corner of each box. boxes[:, :, 2:4] are the box sizes in (h, w) (float)
in terms of the number of pixels of the corresponding feature map size.
    boundaries: a 3-D tensor of shape [batch_size, num_boxes, 2] representing
      the boundary (in (y, x)) of the corresponding feature map for each box.
      Any resampled grid points that go beyond the boundary will be clipped.
    output_size: a scalar indicating the output crop size.
    sample_offset: a float number in [0, 1] indicating the subpixel sample
      offset from grid point.
Returns:
kernel_y: Tensor of size [batch_size, boxes, output_size, 2, 1].
kernel_x: Tensor of size [batch_size, boxes, output_size, 2, 1].
    box_gridy0y1: Tensor of size [batch_size, boxes, output_size, 2]
    box_gridx0x1: Tensor of size [batch_size, boxes, output_size, 2]
"""
boxes_shape = tf.shape(boxes)
batch_size, num_boxes = boxes_shape[0], boxes_shape[1]
if batch_size is None:
batch_size = tf.shape(boxes)[0]
box_grid_x = []
box_grid_y = []
for i in range(output_size):
box_grid_x.append(boxes[:, :, 1] +
(i + sample_offset) * boxes[:, :, 3] / output_size)
box_grid_y.append(boxes[:, :, 0] +
(i + sample_offset) * boxes[:, :, 2] / output_size)
box_grid_x = tf.stack(box_grid_x, axis=2)
box_grid_y = tf.stack(box_grid_y, axis=2)
box_grid_y0 = tf.floor(box_grid_y)
box_grid_x0 = tf.floor(box_grid_x)
box_grid_x0 = tf.maximum(tf.cast(0., dtype=box_grid_x0.dtype), box_grid_x0)
box_grid_y0 = tf.maximum(tf.cast(0., dtype=box_grid_y0.dtype), box_grid_y0)
box_grid_x0 = tf.minimum(box_grid_x0, tf.expand_dims(boundaries[:, :, 1], -1))
box_grid_x1 = tf.minimum(box_grid_x0 + 1,
tf.expand_dims(boundaries[:, :, 1], -1))
box_grid_y0 = tf.minimum(box_grid_y0, tf.expand_dims(boundaries[:, :, 0], -1))
box_grid_y1 = tf.minimum(box_grid_y0 + 1,
tf.expand_dims(boundaries[:, :, 0], -1))
box_gridx0x1 = tf.stack([box_grid_x0, box_grid_x1], axis=-1)
box_gridy0y1 = tf.stack([box_grid_y0, box_grid_y1], axis=-1)
# The RoIAlign feature f can be computed by bilinear interpolation of four
# neighboring feature points f0, f1, f2, and f3.
# f(y, x) = [hy, ly] * [[f00, f01], * [hx, lx]^T
# [f10, f11]]
# f(y, x) = (hy*hx)f00 + (hy*lx)f01 + (ly*hx)f10 + (lx*ly)f11
# f(y, x) = w00*f00 + w01*f01 + w10*f10 + w11*f11
ly = box_grid_y - box_grid_y0
lx = box_grid_x - box_grid_x0
hy = 1.0 - ly
hx = 1.0 - lx
kernel_y = tf.reshape(
tf.stack([hy, ly], axis=3), [batch_size, num_boxes, output_size, 2, 1])
kernel_x = tf.reshape(
tf.stack([hx, lx], axis=3), [batch_size, num_boxes, output_size, 2, 1])
return kernel_y, kernel_x, box_gridy0y1, box_gridx0x1
def multilevel_crop_and_resize(features,
boxes,
output_size=7,
sample_offset=0.5):
"""Crop and resize on multilevel feature pyramid.
  Generates the (output_size, output_size) set of pixels for each input box
  by first locating the box in the correct feature level, and then cropping
  and resizing it using the corresponding feature map of that level.
Args:
features: A dictionary with key as pyramid level and value as features. The
features are in shape of [batch_size, height_l, width_l, num_filters].
boxes: A 3-D Tensor of shape [batch_size, num_boxes, 4]. Each row represents
a box with [y1, x1, y2, x2] in un-normalized coordinates.
output_size: A scalar to indicate the output crop size.
    sample_offset: a float number in [0, 1] indicating the subpixel sample
      offset from grid point.
Returns:
A 5-D tensor representing feature crop of shape
[batch_size, num_boxes, output_size, output_size, num_filters].
"""
with tf.name_scope('multilevel_crop_and_resize'):
levels = list(features.keys())
min_level = int(min(levels))
max_level = int(max(levels))
features_shape = tf.shape(features[str(min_level)])
batch_size, max_feature_height, max_feature_width, num_filters = (
features_shape[0], features_shape[1], features_shape[2],
features_shape[3])
num_boxes = tf.shape(boxes)[1]
# Stack feature pyramid into a features_all of shape
# [batch_size, levels, height, width, num_filters].
features_all = []
feature_heights = []
feature_widths = []
for level in range(min_level, max_level + 1):
shape = features[str(level)].get_shape().as_list()
feature_heights.append(shape[1])
feature_widths.append(shape[2])
      # Concat tensor of [batch_size, height_l * width_l, num_filters] for
      # each level.
features_all.append(
tf.reshape(features[str(level)], [batch_size, -1, num_filters]))
features_r2 = tf.reshape(tf.concat(features_all, 1), [-1, num_filters])
# Calculate height_l * width_l for each level.
level_dim_sizes = [
feature_widths[i] * feature_heights[i]
for i in range(len(feature_widths))
]
    # level_dim_offsets is the accumulated sum of level_dim_sizes.
level_dim_offsets = [0]
for i in range(len(feature_widths) - 1):
level_dim_offsets.append(level_dim_offsets[i] + level_dim_sizes[i])
batch_dim_size = level_dim_offsets[-1] + level_dim_sizes[-1]
level_dim_offsets = tf.constant(level_dim_offsets, tf.int32)
height_dim_sizes = tf.constant(feature_widths, tf.int32)
# Assigns boxes to the right level.
box_width = boxes[:, :, 3] - boxes[:, :, 1]
box_height = boxes[:, :, 2] - boxes[:, :, 0]
areas_sqrt = tf.sqrt(
tf.cast(box_height, tf.float32) * tf.cast(box_width, tf.float32))
levels = tf.cast(
tf.math.floordiv(
tf.math.log(tf.math.divide_no_nan(areas_sqrt, 224.0)),
tf.math.log(2.0)) + 4.0,
dtype=tf.int32)
    # Clips levels to the range [min_level, max_level].
levels = tf.minimum(max_level, tf.maximum(levels, min_level))
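    # E.g., a box of roughly 224x224 pixels (sqrt(area) = 224) is assigned
    # floor(log2(224 / 224)) + 4 = 4, i.e. level P4; doubling each side
    # (448x448) moves it up one level to P5.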
# Projects box location and sizes to corresponding feature levels.
scale_to_level = tf.cast(
tf.pow(tf.constant(2.0), tf.cast(levels, tf.float32)),
dtype=boxes.dtype)
boxes /= tf.expand_dims(scale_to_level, axis=2)
box_width /= scale_to_level
box_height /= scale_to_level
boxes = tf.concat([boxes[:, :, 0:2],
tf.expand_dims(box_height, -1),
tf.expand_dims(box_width, -1)], axis=-1)
# Maps levels to [0, max_level-min_level].
levels -= min_level
level_strides = tf.pow([[2.0]], tf.cast(levels, tf.float32))
boundary = tf.cast(
tf.concat([
tf.expand_dims(
[[tf.cast(max_feature_height, tf.float32)]] / level_strides - 1,
axis=-1),
tf.expand_dims(
[[tf.cast(max_feature_width, tf.float32)]] / level_strides - 1,
axis=-1),
],
axis=-1), boxes.dtype)
# Compute grid positions.
kernel_y, kernel_x, box_gridy0y1, box_gridx0x1 = _compute_grid_positions(
boxes, boundary, output_size, sample_offset)
x_indices = tf.cast(
tf.reshape(box_gridx0x1, [batch_size, num_boxes, output_size * 2]),
dtype=tf.int32)
y_indices = tf.cast(
tf.reshape(box_gridy0y1, [batch_size, num_boxes, output_size * 2]),
dtype=tf.int32)
batch_size_offset = tf.tile(
tf.reshape(
tf.range(batch_size) * batch_dim_size, [batch_size, 1, 1, 1]),
[1, num_boxes, output_size * 2, output_size * 2])
# Get level offset for each box. Each box belongs to one level.
levels_offset = tf.tile(
tf.reshape(
tf.gather(level_dim_offsets, levels),
[batch_size, num_boxes, 1, 1]),
[1, 1, output_size * 2, output_size * 2])
y_indices_offset = tf.tile(
tf.reshape(
y_indices * tf.expand_dims(tf.gather(height_dim_sizes, levels), -1),
[batch_size, num_boxes, output_size * 2, 1]),
[1, 1, 1, output_size * 2])
x_indices_offset = tf.tile(
tf.reshape(x_indices, [batch_size, num_boxes, 1, output_size * 2]),
[1, 1, output_size * 2, 1])
indices = tf.reshape(
batch_size_offset + levels_offset + y_indices_offset + x_indices_offset,
[-1])
# TODO(wangtao): replace tf.gather with tf.gather_nd and try to get similar
# performance.
features_per_box = tf.reshape(
tf.gather(features_r2, indices),
[batch_size, num_boxes, output_size * 2, output_size * 2, num_filters])
# Bilinear interpolation.
features_per_box = _feature_bilinear_interpolation(
features_per_box, kernel_y, kernel_x)
return features_per_box
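# A usage sketch with hypothetical shapes: a 3-level pyramid (levels 3-5) with
# 256 filters, batch 2, and 10 boxes per image yields 7x7 crops.
#
#   features = {
#       '3': tf.random.uniform([2, 80, 80, 256]),
#       '4': tf.random.uniform([2, 40, 40, 256]),
#       '5': tf.random.uniform([2, 20, 20, 256]),
#   }
#   boxes = tf.constant([[[0., 0., 64., 64.]] * 10] * 2)  # [y1, x1, y2, x2]
#   crops = multilevel_crop_and_resize(features, boxes, output_size=7)
#   # crops.shape == [2, 10, 7, 7, 256]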
def _selective_crop_and_resize(features,
boxes,
box_levels,
boundaries,
output_size=7,
sample_offset=0.5,
use_einsum_gather=False):
"""Crop and resize boxes on a set of feature maps.
Given multiple features maps indexed by different levels, and a set of boxes
where each box is mapped to a certain level, it selectively crops and resizes
boxes from the corresponding feature maps to generate the box features.
We follow the ROIAlign technique (see https://arxiv.org/pdf/1703.06870.pdf,
figure 3 for reference). Specifically, for each feature map, we select an
(output_size, output_size) set of pixels corresponding to the box location,
and then use bilinear interpolation to select the feature value for each
pixel.
For performance, we perform the gather and interpolation on all layers as a
single operation. In this op the multi-level features are first stacked and
gathered into [2*output_size, 2*output_size] feature points. Then bilinear
interpolation is performed on the gathered feature points to generate
[output_size, output_size] RoIAlign feature map.
Here is the step-by-step algorithm:
1. The multi-level features are gathered into a
[batch_size, num_boxes, output_size*2, output_size*2, num_filters]
Tensor. The Tensor contains four neighboring feature points for each
vertex in the output grid.
2. Compute the interpolation kernel of shape
[batch_size, num_boxes, output_size*2, output_size*2]. The last 2 axis
can be seen as stacking 2x2 interpolation kernels for all vertices in the
output grid.
3. Element-wise multiply the gathered features and interpolation kernel.
Then apply 2x2 average pooling to reduce spatial dimension to
output_size.
Args:
features: a 5-D tensor of shape [batch_size, num_levels, max_height,
max_width, num_filters] where cropping and resizing are based.
boxes: a 3-D tensor of shape [batch_size, num_boxes, 4] encoding the
information of each box w.r.t. the corresponding feature map.
boxes[:, :, 0:2] are the grid position in (y, x) (float) of the top-left
corner of each box. boxes[:, :, 2:4] are the box sizes in (h, w) (float)
in terms of the number of pixels of the corresponding feature map size.
box_levels: a 3-D tensor of shape [batch_size, num_boxes, 1] representing
the 0-based corresponding feature level index of each box.
    boundaries: a 3-D tensor of shape [batch_size, num_boxes, 2] representing
      the boundary (in (y, x)) of the corresponding feature map for each box.
      Any resampled grid points that go beyond the boundary will be clipped.
    output_size: a scalar indicating the output crop size.
    sample_offset: a float number in [0, 1] indicating the subpixel sample
      offset from grid point.
    use_einsum_gather: whether to use einsum instead of gather. Replacing
      gather with einsum can improve performance when the feature size is not
      large, and einsum is also friendly to model partitioning. Gather
      performs better when the feature size is very large and there are
      multiple box levels.
Returns:
features_per_box: a 5-D tensor of shape
[batch_size, num_boxes, output_size, output_size, num_filters]
representing the cropped features.
"""
(batch_size, num_levels, max_feature_height, max_feature_width,
num_filters) = features.get_shape().as_list()
if batch_size is None:
batch_size = tf.shape(features)[0]
_, num_boxes, _ = boxes.get_shape().as_list()
kernel_y, kernel_x, box_gridy0y1, box_gridx0x1 = _compute_grid_positions(
boxes, boundaries, output_size, sample_offset)
x_indices = tf.cast(
tf.reshape(box_gridx0x1, [batch_size, num_boxes, output_size * 2]),
dtype=tf.int32)
y_indices = tf.cast(
tf.reshape(box_gridy0y1, [batch_size, num_boxes, output_size * 2]),
dtype=tf.int32)
if use_einsum_gather:
    # Bilinear interpolation is done during the last two gathers:
# f(y, x) = [hy, ly] * [[f00, f01], * [hx, lx]^T
# [f10, f11]]
# [[f00, f01],
# [f10, f11]] = tf.einsum(tf.einsum(features, y_one_hot), x_one_hot)
# where [hy, ly] and [hx, lx] are the bilinear interpolation kernel.
y_indices = tf.cast(
tf.reshape(box_gridy0y1, [batch_size, num_boxes, output_size, 2]),
dtype=tf.int32)
x_indices = tf.cast(
tf.reshape(box_gridx0x1, [batch_size, num_boxes, output_size, 2]),
dtype=tf.int32)
# shape is [batch_size, num_boxes, output_size, 2, height]
grid_y_one_hot = tf.one_hot(
tf.cast(y_indices, tf.int32), max_feature_height, dtype=kernel_y.dtype)
# shape is [batch_size, num_boxes, output_size, 2, width]
grid_x_one_hot = tf.one_hot(
tf.cast(x_indices, tf.int32), max_feature_width, dtype=kernel_x.dtype)
# shape is [batch_size, num_boxes, output_size, height]
grid_y_weight = tf.reduce_sum(
tf.multiply(grid_y_one_hot, kernel_y), axis=-2)
# shape is [batch_size, num_boxes, output_size, width]
grid_x_weight = tf.reduce_sum(
tf.multiply(grid_x_one_hot, kernel_x), axis=-2)
# Gather for y_axis.
# shape is [batch_size, num_boxes, output_size, width, features]
features_per_box = tf.einsum('bmhwf,bmoh->bmowf', features,
tf.cast(grid_y_weight, features.dtype))
# Gather for x_axis.
# shape is [batch_size, num_boxes, output_size, output_size, features]
features_per_box = tf.einsum('bmhwf,bmow->bmhof', features_per_box,
tf.cast(grid_x_weight, features.dtype))
else:
height_dim_offset = max_feature_width
level_dim_offset = max_feature_height * height_dim_offset
batch_dim_offset = num_levels * level_dim_offset
batch_size_offset = tf.tile(
tf.reshape(
tf.range(batch_size) * batch_dim_offset, [batch_size, 1, 1, 1]),
[1, num_boxes, output_size * 2, output_size * 2])
box_levels_offset = tf.tile(
tf.reshape(box_levels * level_dim_offset,
[batch_size, num_boxes, 1, 1]),
[1, 1, output_size * 2, output_size * 2])
y_indices_offset = tf.tile(
tf.reshape(y_indices * height_dim_offset,
[batch_size, num_boxes, output_size * 2, 1]),
[1, 1, 1, output_size * 2])
x_indices_offset = tf.tile(
tf.reshape(x_indices, [batch_size, num_boxes, 1, output_size * 2]),
[1, 1, output_size * 2, 1])
indices = tf.reshape(
batch_size_offset + box_levels_offset + y_indices_offset +
x_indices_offset, [-1])
features = tf.reshape(features, [-1, num_filters])
# TODO(wangtao): replace tf.gather with tf.gather_nd and try to get similar
# performance.
features_per_box = tf.reshape(
tf.gather(features, indices),
[batch_size, num_boxes, output_size * 2, output_size * 2, num_filters])
features_per_box = _feature_bilinear_interpolation(
features_per_box, kernel_y, kernel_x)
return features_per_box
def crop_mask_in_target_box(masks,
boxes,
target_boxes,
output_size,
sample_offset=0,
use_einsum=True):
"""Crop masks in target boxes.
  Args:
    masks: A tensor with a shape of [batch_size, num_masks, height, width].
    boxes: a float tensor representing box coordinates that tightly enclose
      masks with a shape of [batch_size, num_masks, 4] in un-normalized
      coordinates. A box is represented by [ymin, xmin, ymax, xmax].
    target_boxes: a float tensor representing target box coordinates for masks
      with a shape of [batch_size, num_masks, 4] in un-normalized coordinates.
      A box is represented by [ymin, xmin, ymax, xmax].
    output_size: A scalar to indicate the output crop size. Only square
      outputs are currently supported.
    sample_offset: a float number in [0, 1] indicating the subpixel sample
      offset from grid point.
    use_einsum: Use einsum to replace gather in selective_crop_and_resize.
use_einsum: Use einsum to replace gather in selective_crop_and_resize.
Returns:
A 4-D tensor representing feature crop of shape
[batch_size, num_boxes, output_size, output_size].
"""
with tf.name_scope('crop_mask_in_target_box'):
# Cast to float32, as the y_transform and other transform variables may
# overflow in float16
masks = tf.cast(masks, tf.float32)
boxes = tf.cast(boxes, tf.float32)
target_boxes = tf.cast(target_boxes, tf.float32)
batch_size, num_masks, height, width = masks.get_shape().as_list()
if batch_size is None:
batch_size = tf.shape(masks)[0]
masks = tf.reshape(masks, [batch_size * num_masks, height, width, 1])
# Pad zeros on the boundary of masks.
masks = tf.image.pad_to_bounding_box(masks, 2, 2, height + 4, width + 4)
masks = tf.reshape(masks, [batch_size, num_masks, height+4, width+4, 1])
# Projects target box locations and sizes to corresponding cropped
# mask coordinates.
gt_y_min, gt_x_min, gt_y_max, gt_x_max = tf.split(
value=boxes, num_or_size_splits=4, axis=2)
bb_y_min, bb_x_min, bb_y_max, bb_x_max = tf.split(
value=target_boxes, num_or_size_splits=4, axis=2)
    # Scale y/h by the mask height and x/w by the mask width (the two only
    # differ for non-square masks).
    y_transform = (bb_y_min - gt_y_min) * height / (
        gt_y_max - gt_y_min + _EPSILON) + 2
    x_transform = (bb_x_min - gt_x_min) * width / (
        gt_x_max - gt_x_min + _EPSILON) + 2
    h_transform = (bb_y_max - bb_y_min) * height / (
        gt_y_max - gt_y_min + _EPSILON)
    w_transform = (bb_x_max - bb_x_min) * width / (
        gt_x_max - gt_x_min + _EPSILON)
boundaries = tf.concat(
[tf.ones_like(y_transform) * ((height + 4) - 1),
tf.ones_like(x_transform) * ((width + 4) - 1)],
axis=-1)
boundaries = tf.cast(boundaries, dtype=y_transform.dtype)
# Reshape tensors to have the right shape for selective_crop_and_resize.
    transformed_boxes = tf.concat(
        [y_transform, x_transform, h_transform, w_transform], -1)
levels = tf.tile(tf.reshape(tf.range(num_masks), [1, num_masks]),
[batch_size, 1])
cropped_masks = _selective_crop_and_resize(
masks,
        transformed_boxes,
levels,
boundaries,
output_size,
sample_offset=sample_offset,
use_einsum_gather=use_einsum)
cropped_masks = tf.squeeze(cropped_masks, axis=-1)
return cropped_masks
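# A usage sketch with hypothetical shapes: re-crop 28x28 instance masks from
# their tight boxes into target boxes (here identical, so the masks survive up
# to the 2-pixel zero padding and resampling).
#
#   masks = tf.random.uniform([2, 5, 28, 28])
#   boxes = tf.constant([[[0., 0., 28., 28.]] * 5] * 2)
#   cropped = crop_mask_in_target_box(masks, boxes, boxes, output_size=28)
#   # cropped.shape == [2, 5, 28, 28]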
def nearest_upsampling(data, scale, use_keras_layer=False):
"""Nearest neighbor upsampling implementation.
Args:
data: A tensor with a shape of [batch, height_in, width_in, channels].
scale: An integer multiple to scale resolution of input data.
use_keras_layer: If True, use keras Upsampling2D layer.
Returns:
data_up: A tensor with a shape of
[batch, height_in*scale, width_in*scale, channels]. Same dtype as input
data.
"""
if use_keras_layer:
return tf.keras.layers.UpSampling2D(size=(scale, scale),
interpolation='nearest')(data)
with tf.name_scope('nearest_upsampling'):
bs, _, _, c = data.get_shape().as_list()
shape = tf.shape(input=data)
h = shape[1]
w = shape[2]
bs = -1 if bs is None else bs
# Uses reshape to quickly upsample the input. The nearest pixel is selected
# via tiling.
data = tf.tile(
tf.reshape(data, [bs, h, 1, w, 1, c]), [1, 1, scale, 1, scale, 1])
return tf.reshape(data, [bs, h * scale, w * scale, c])
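# Worked example: scale=2 duplicates every pixel into a 2x2 block.
#
#   x = tf.reshape(tf.range(4, dtype=tf.float32), [1, 2, 2, 1])
#   y = nearest_upsampling(x, scale=2)
#   # y[0, :, :, 0] ==
#   #   [[0., 0., 1., 1.],
#   #    [0., 0., 1., 1.],
#   #    [2., 2., 3., 3.],
#   #    [2., 2., 3., 3.]]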
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Definition of target gather, which gathers targets from indices."""
import tensorflow as tf
class TargetGather:
"""Targer gather for dense object detector."""
def __call__(self, labels, match_indices, mask=None, mask_val=0.0):
"""Labels anchors with ground truth inputs.
B: batch_size
N: number of groundtruth boxes.
Args:
labels: An integer tensor with shape [N, dims] or [B, N, ...] representing
groundtruth labels.
match_indices: An integer tensor with shape [M] or [B, M] representing
match label index.
      mask: A boolean tensor with shape [M, dims] or [B, M, ...] indicating
        which entries should be masked out.
      mask_val: An integer value to fill in for masked entries.
Returns:
target: An integer Tensor with shape [M] or [B, M]
Raises:
      ValueError: If `labels` has rank larger than 3.
"""
if len(labels.shape) <= 2:
return self._gather_unbatched(labels, match_indices, mask, mask_val)
elif len(labels.shape) == 3:
return self._gather_batched(labels, match_indices, mask, mask_val)
else:
raise ValueError("`TargetGather` does not support `labels` with rank "
"larger than 3, got {}".format(len(labels.shape)))
def _gather_unbatched(self, labels, match_indices, mask, mask_val):
"""Gather based on unbatched labels and boxes."""
num_gt_boxes = tf.shape(labels)[0]
def _assign_when_rows_empty():
if len(labels.shape) > 1:
mask_shape = [match_indices.shape[0], labels.shape[-1]]
else:
mask_shape = [match_indices.shape[0]]
return tf.cast(mask_val, labels.dtype) * tf.ones(
mask_shape, dtype=labels.dtype)
def _assign_when_rows_not_empty():
targets = tf.gather(labels, match_indices)
if mask is None:
return targets
else:
masked_targets = tf.cast(mask_val, labels.dtype) * tf.ones_like(
mask, dtype=labels.dtype)
return tf.where(mask, masked_targets, targets)
return tf.cond(tf.greater(num_gt_boxes, 0),
_assign_when_rows_not_empty,
_assign_when_rows_empty)
def _gather_batched(self, labels, match_indices, mask, mask_val):
"""Gather based on batched labels."""
batch_size = labels.shape[0]
if batch_size == 1:
if mask is not None:
result = self._gather_unbatched(
tf.squeeze(labels, axis=0), tf.squeeze(match_indices, axis=0),
tf.squeeze(mask, axis=0), mask_val)
else:
result = self._gather_unbatched(
tf.squeeze(labels, axis=0), tf.squeeze(match_indices, axis=0),
None, mask_val)
return tf.expand_dims(result, axis=0)
else:
indices_shape = tf.shape(match_indices)
indices_dtype = match_indices.dtype
batch_indices = (tf.expand_dims(
tf.range(indices_shape[0], dtype=indices_dtype), axis=-1) *
tf.ones([1, indices_shape[-1]], dtype=indices_dtype))
gather_nd_indices = tf.stack(
[batch_indices, match_indices], axis=-1)
targets = tf.gather_nd(labels, gather_nd_indices)
if mask is None:
return targets
else:
masked_targets = tf.cast(mask_val, labels.dtype) * tf.ones_like(
mask, dtype=labels.dtype)
return tf.where(mask, masked_targets, targets)
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for target_gather.py."""
import tensorflow as tf
from official.vision.beta.ops import target_gather
class TargetGatherTest(tf.test.TestCase):
def test_target_gather_batched(self):
gt_boxes = tf.constant(
[[
[0, 0, 5, 5],
[0, 5, 5, 10],
[5, 0, 10, 5],
[5, 5, 10, 10],
]],
dtype=tf.float32)
gt_classes = tf.constant([[[2], [10], [3], [-1]]], dtype=tf.int32)
labeler = target_gather.TargetGather()
match_indices = tf.constant([[0, 2]], dtype=tf.int32)
match_indicators = tf.constant([[-2, 1]])
mask = tf.less_equal(match_indicators, 0)
cls_mask = tf.expand_dims(mask, -1)
matched_gt_classes = labeler(gt_classes, match_indices, cls_mask)
box_mask = tf.tile(cls_mask, [1, 1, 4])
matched_gt_boxes = labeler(gt_boxes, match_indices, box_mask)
self.assertAllEqual(
matched_gt_classes.numpy(), [[[0], [3]]])
self.assertAllClose(
matched_gt_boxes.numpy(), [[[0, 0, 0, 0], [5, 0, 10, 5]]])
def test_target_gather_unbatched(self):
gt_boxes = tf.constant(
[
[0, 0, 5, 5],
[0, 5, 5, 10],
[5, 0, 10, 5],
[5, 5, 10, 10],
],
dtype=tf.float32)
gt_classes = tf.constant([[2], [10], [3], [-1]], dtype=tf.int32)
labeler = target_gather.TargetGather()
match_indices = tf.constant([0, 2], dtype=tf.int32)
match_indicators = tf.constant([-2, 1])
mask = tf.less_equal(match_indicators, 0)
cls_mask = tf.expand_dims(mask, -1)
matched_gt_classes = labeler(gt_classes, match_indices, cls_mask)
box_mask = tf.tile(cls_mask, [1, 4])
matched_gt_boxes = labeler(gt_boxes, match_indices, box_mask)
self.assertAllEqual(
matched_gt_classes.numpy(), [[0], [3]])
self.assertAllClose(
matched_gt_boxes.numpy(), [[0, 0, 0, 0], [5, 0, 10, 5]])
if __name__ == '__main__':
tf.test.main()
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Detection input and model functions for serving/inference."""
from typing import Mapping, Text, Tuple
import tensorflow as tf
from official.vision.beta import configs
from official.vision.beta.modeling import factory
from official.vision.beta.ops import anchor
from official.vision.beta.ops import box_ops
from official.vision.beta.ops import preprocess_ops
from official.vision.beta.serving import export_base
MEAN_RGB = (0.485 * 255, 0.456 * 255, 0.406 * 255)
STDDEV_RGB = (0.229 * 255, 0.224 * 255, 0.225 * 255)
class DetectionModule(export_base.ExportModule):
"""Detection Module."""
def _build_model(self):
if self._batch_size is None:
raise ValueError('batch_size cannot be None for detection models.')
input_specs = tf.keras.layers.InputSpec(shape=[self._batch_size] +
self._input_image_size + [3])
if isinstance(self.params.task.model, configs.maskrcnn.MaskRCNN):
model = factory.build_maskrcnn(
input_specs=input_specs, model_config=self.params.task.model)
elif isinstance(self.params.task.model, configs.retinanet.RetinaNet):
model = factory.build_retinanet(
input_specs=input_specs, model_config=self.params.task.model)
else:
raise ValueError('Detection module not implemented for {} model.'.format(
type(self.params.task.model)))
return model
def _build_anchor_boxes(self):
"""Builds and returns anchor boxes."""
model_params = self.params.task.model
input_anchor = anchor.build_anchor_generator(
min_level=model_params.min_level,
max_level=model_params.max_level,
num_scales=model_params.anchor.num_scales,
aspect_ratios=model_params.anchor.aspect_ratios,
anchor_size=model_params.anchor.anchor_size)
return input_anchor(
image_size=(self._input_image_size[0], self._input_image_size[1]))
def _build_inputs(self, image):
"""Builds detection model inputs for serving."""
model_params = self.params.task.model
# Normalizes image with mean and std pixel values.
image = preprocess_ops.normalize_image(image,
offset=MEAN_RGB,
scale=STDDEV_RGB)
image, image_info = preprocess_ops.resize_and_crop_image(
image,
self._input_image_size,
padded_size=preprocess_ops.compute_padded_size(
self._input_image_size, 2**model_params.max_level),
aug_scale_min=1.0,
aug_scale_max=1.0)
anchor_boxes = self._build_anchor_boxes()
return image, anchor_boxes, image_info
  def preprocess(
      self, images: tf.Tensor
  ) -> Tuple[tf.Tensor, Mapping[Text, tf.Tensor], tf.Tensor]:
"""Preprocess inputs to be suitable for the model.
Args:
images: The images tensor.
Returns:
images: The images tensor cast to float.
anchor_boxes: Dict mapping anchor levels to anchor boxes.
image_info: Tensor containing the details of the image resizing.
"""
model_params = self.params.task.model
with tf.device('cpu:0'):
images = tf.cast(images, dtype=tf.float32)
# Tensor Specs for map_fn outputs (images, anchor_boxes, and image_info).
images_spec = tf.TensorSpec(shape=self._input_image_size + [3],
dtype=tf.float32)
num_anchors = model_params.anchor.num_scales * len(
model_params.anchor.aspect_ratios) * 4
anchor_shapes = []
for level in range(model_params.min_level, model_params.max_level + 1):
anchor_level_spec = tf.TensorSpec(
shape=[
self._input_image_size[0] // 2**level,
self._input_image_size[1] // 2**level, num_anchors
],
dtype=tf.float32)
anchor_shapes.append((str(level), anchor_level_spec))
image_info_spec = tf.TensorSpec(shape=[4, 2], dtype=tf.float32)
images, anchor_boxes, image_info = tf.nest.map_structure(
tf.identity,
tf.map_fn(
self._build_inputs,
elems=images,
fn_output_signature=(images_spec, dict(anchor_shapes),
image_info_spec),
parallel_iterations=32))
return images, anchor_boxes, image_info
def serve(self, images: tf.Tensor):
"""Cast image to float and run inference.
Args:
images: uint8 Tensor of shape [batch_size, None, None, 3]
Returns:
Tensor holding detection output logits.
"""
# Skip image preprocessing when input_type is tflite so it is compatible
# with TFLite quantization.
if self._input_type != 'tflite':
images, anchor_boxes, image_info = self.preprocess(images)
else:
with tf.device('cpu:0'):
anchor_boxes = self._build_anchor_boxes()
# image_info is a 3D tensor of shape [batch_size, 4, 2]. It is in the
# format of [[original_height, original_width],
# [desired_height, desired_width], [y_scale, x_scale],
# [y_offset, x_offset]]. When input_type is tflite, input image is
# supposed to be preprocessed already.
image_info = tf.convert_to_tensor([[
self._input_image_size, self._input_image_size, [1.0, 1.0], [0, 0]
]],
dtype=tf.float32)
input_image_shape = image_info[:, 1, :]
    # To work around a keras.Model limitation when saving a model whose layers
    # take multiple inputs, we call `model.call` directly to trigger the
    # forward pass. Note that this bypasses some Keras magic that happens in
    # `__call__`.
detections = self.model.call(
images=images,
image_shape=input_image_shape,
anchor_boxes=anchor_boxes,
training=False)
if self.params.task.model.detection_generator.apply_nms:
# For RetinaNet model, apply export_config.
# TODO(huizhongc): Add export_config to fasterrcnn and maskrcnn as needed.
if isinstance(self.params.task.model, configs.retinanet.RetinaNet):
export_config = self.params.task.export_config
# Normalize detection box coordinates to [0, 1].
if export_config.output_normalized_coordinates:
detection_boxes = (
detections['detection_boxes'] /
tf.tile(image_info[:, 2:3, :], [1, 1, 2]))
detections['detection_boxes'] = box_ops.normalize_boxes(
detection_boxes, image_info[:, 0:1, :])
# Cast num_detections and detection_classes to float. This allows the
# model inference to work on chain (go/chain) as chain requires floating
# point outputs.
if export_config.cast_num_detections_to_float:
detections['num_detections'] = tf.cast(
detections['num_detections'], dtype=tf.float32)
if export_config.cast_detection_classes_to_float:
detections['detection_classes'] = tf.cast(
detections['detection_classes'], dtype=tf.float32)
final_outputs = {
'detection_boxes': detections['detection_boxes'],
'detection_scores': detections['detection_scores'],
'detection_classes': detections['detection_classes'],
'num_detections': detections['num_detections']
}
else:
final_outputs = {
'decoded_boxes': detections['decoded_boxes'],
'decoded_box_scores': detections['decoded_box_scores']
}
if 'detection_masks' in detections.keys():
final_outputs['detection_masks'] = detections['detection_masks']
final_outputs.update({'image_info': image_info})
return final_outputs
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test for image detection export lib."""
import io
import os
from absl.testing import parameterized
import numpy as np
from PIL import Image
import tensorflow as tf
from official.core import exp_factory
from official.vision.beta import configs # pylint: disable=unused-import
from official.vision.beta.serving import detection
class DetectionExportTest(tf.test.TestCase, parameterized.TestCase):
def _get_detection_module(self, experiment_name, input_type):
params = exp_factory.get_exp_config(experiment_name)
params.task.model.backbone.resnet.model_id = 18
params.task.model.detection_generator.nms_version = 'batched'
detection_module = detection.DetectionModule(
params,
batch_size=1,
input_image_size=[640, 640],
input_type=input_type)
return detection_module
def _export_from_module(self, module, input_type, save_directory):
signatures = module.get_inference_signatures(
{input_type: 'serving_default'})
tf.saved_model.save(module, save_directory, signatures=signatures)
def _get_dummy_input(self, input_type, batch_size, image_size):
"""Get dummy input for the given input type."""
h, w = image_size
if input_type == 'image_tensor':
return tf.zeros((batch_size, h, w, 3), dtype=np.uint8)
elif input_type == 'image_bytes':
image = Image.fromarray(np.zeros((h, w, 3), dtype=np.uint8))
byte_io = io.BytesIO()
image.save(byte_io, 'PNG')
      return [byte_io.getvalue() for _ in range(batch_size)]
elif input_type == 'tf_example':
image_tensor = tf.zeros((h, w, 3), dtype=tf.uint8)
encoded_jpeg = tf.image.encode_jpeg(tf.constant(image_tensor)).numpy()
example = tf.train.Example(
features=tf.train.Features(
feature={
'image/encoded':
tf.train.Feature(
bytes_list=tf.train.BytesList(value=[encoded_jpeg])),
})).SerializeToString()
      return [example for _ in range(batch_size)]
elif input_type == 'tflite':
return tf.zeros((batch_size, h, w, 3), dtype=np.float32)
@parameterized.parameters(
('image_tensor', 'fasterrcnn_resnetfpn_coco', [384, 384]),
('image_bytes', 'fasterrcnn_resnetfpn_coco', [640, 640]),
('tf_example', 'fasterrcnn_resnetfpn_coco', [640, 640]),
('tflite', 'fasterrcnn_resnetfpn_coco', [640, 640]),
('image_tensor', 'maskrcnn_resnetfpn_coco', [640, 640]),
('image_bytes', 'maskrcnn_resnetfpn_coco', [640, 384]),
('tf_example', 'maskrcnn_resnetfpn_coco', [640, 640]),
('tflite', 'maskrcnn_resnetfpn_coco', [640, 640]),
('image_tensor', 'retinanet_resnetfpn_coco', [640, 640]),
('image_bytes', 'retinanet_resnetfpn_coco', [640, 640]),
('tf_example', 'retinanet_resnetfpn_coco', [384, 640]),
('tflite', 'retinanet_resnetfpn_coco', [640, 640]),
('image_tensor', 'retinanet_resnetfpn_coco', [384, 384]),
('image_bytes', 'retinanet_spinenet_coco', [640, 640]),
('tf_example', 'retinanet_spinenet_coco', [640, 384]),
('tflite', 'retinanet_spinenet_coco', [640, 640]),
)
def test_export(self, input_type, experiment_name, image_size):
tmp_dir = self.get_temp_dir()
module = self._get_detection_module(experiment_name, input_type)
self._export_from_module(module, input_type, tmp_dir)
self.assertTrue(os.path.exists(os.path.join(tmp_dir, 'saved_model.pb')))
self.assertTrue(
os.path.exists(os.path.join(tmp_dir, 'variables', 'variables.index')))
self.assertTrue(
os.path.exists(
os.path.join(tmp_dir, 'variables',
'variables.data-00000-of-00001')))
imported = tf.saved_model.load(tmp_dir)
detection_fn = imported.signatures['serving_default']
images = self._get_dummy_input(
input_type, batch_size=1, image_size=image_size)
if input_type == 'tflite':
processed_images = tf.zeros(image_size + [3], dtype=tf.float32)
anchor_boxes = module._build_anchor_boxes()
image_info = tf.convert_to_tensor(
[image_size, image_size, [1.0, 1.0], [0, 0]], dtype=tf.float32)
else:
processed_images, anchor_boxes, image_info = module._build_inputs(
tf.zeros((224, 224, 3), dtype=tf.uint8))
image_shape = image_info[1, :]
image_shape = tf.expand_dims(image_shape, 0)
processed_images = tf.expand_dims(processed_images, 0)
for l, l_boxes in anchor_boxes.items():
anchor_boxes[l] = tf.expand_dims(l_boxes, 0)
expected_outputs = module.model(
images=processed_images,
image_shape=image_shape,
anchor_boxes=anchor_boxes,
training=False)
outputs = detection_fn(tf.constant(images))
self.assertAllClose(outputs['num_detections'].numpy(),
expected_outputs['num_detections'].numpy())
def test_build_model_fail_with_none_batch_size(self):
params = exp_factory.get_exp_config('retinanet_resnetfpn_coco')
with self.assertRaisesRegex(
ValueError, 'batch_size cannot be None for detection models.'):
detection.DetectionModule(
params, batch_size=None, input_image_size=[640, 640])
if __name__ == '__main__':
tf.test.main()
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base class for model export."""
import abc
from typing import Dict, List, Mapping, Optional, Text
import tensorflow as tf
from official.core import config_definitions as cfg
from official.core import export_base
class ExportModule(export_base.ExportModule, metaclass=abc.ABCMeta):
"""Base Export Module."""
def __init__(self,
params: cfg.ExperimentConfig,
*,
batch_size: int,
input_image_size: List[int],
input_type: str = 'image_tensor',
num_channels: int = 3,
model: Optional[tf.keras.Model] = None):
"""Initializes a module for export.
Args:
params: Experiment params.
batch_size: The batch size of the model input. Can be `int` or None.
input_image_size: List or Tuple of size of the input image. For 2D image,
it is [height, width].
input_type: The input signature type.
num_channels: The number of the image channels.
model: A tf.keras.Model instance to be exported.
"""
self.params = params
self._batch_size = batch_size
self._input_image_size = input_image_size
self._num_channels = num_channels
self._input_type = input_type
if model is None:
model = self._build_model() # pylint: disable=assignment-from-none
super().__init__(params=params, model=model)
def _decode_image(self, encoded_image_bytes: str) -> tf.Tensor:
"""Decodes an image bytes to an image tensor.
    Uses `tf.image.decode_image` to decode an image if the input is expected
    to be a 2D image; otherwise uses `tf.io.decode_raw` to convert the raw
    bytes to a tensor and reshapes it to the desired shape.
Args:
encoded_image_bytes: An encoded image string to be decoded.
Returns:
A decoded image tensor.
"""
if len(self._input_image_size) == 2:
# Decode an image if 2D input is expected.
image_tensor = tf.image.decode_image(
encoded_image_bytes, channels=self._num_channels)
image_tensor.set_shape((None, None, self._num_channels))
else:
# Convert raw bytes into a tensor and reshape it, if not 2D input.
image_tensor = tf.io.decode_raw(encoded_image_bytes, out_type=tf.uint8)
image_tensor = tf.reshape(image_tensor,
self._input_image_size + [self._num_channels])
return image_tensor
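  # Example, assuming a 2D input_image_size: a PNG byte string decodes to an
  # HxWxnum_channels uint8 tensor with the channel count pinned statically.
  #
  #   png_bytes = tf.io.encode_png(tf.zeros([8, 8, 3], tf.uint8))
  #   image = module._decode_image(png_bytes)  # static shape: [None, None, 3]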
def _decode_tf_example(
self, tf_example_string_tensor: tf.train.Example) -> tf.Tensor:
"""Decodes a TF Example to an image tensor.
Args:
tf_example_string_tensor: A tf.train.Example of encoded image and other
information.
Returns:
A decoded image tensor.
"""
keys_to_features = {'image/encoded': tf.io.FixedLenFeature((), tf.string)}
parsed_tensors = tf.io.parse_single_example(
serialized=tf_example_string_tensor, features=keys_to_features)
image_tensor = self._decode_image(parsed_tensors['image/encoded'])
return image_tensor
def _build_model(self, **kwargs):
"""Returns a model built from the params."""
return None
@tf.function
def inference_from_image_tensors(
self, inputs: tf.Tensor) -> Mapping[str, tf.Tensor]:
return self.serve(inputs)
@tf.function
def inference_for_tflite(self, inputs: tf.Tensor) -> Mapping[str, tf.Tensor]:
return self.serve(inputs)
@tf.function
def inference_from_image_bytes(self, inputs: tf.Tensor):
with tf.device('cpu:0'):
images = tf.nest.map_structure(
tf.identity,
tf.map_fn(
self._decode_image,
elems=inputs,
fn_output_signature=tf.TensorSpec(
shape=[None] * len(self._input_image_size) +
[self._num_channels],
dtype=tf.uint8),
parallel_iterations=32))
images = tf.stack(images)
return self.serve(images)
@tf.function
def inference_from_tf_example(self,
inputs: tf.Tensor) -> Mapping[str, tf.Tensor]:
with tf.device('cpu:0'):
images = tf.nest.map_structure(
tf.identity,
tf.map_fn(
self._decode_tf_example,
elems=inputs,
# Height/width of the shape of input images is unspecified (None)
# at the time of decoding the example, but the shape will
# be adjusted to conform to the input layer of the model,
# by _run_inference_on_image_tensors() below.
fn_output_signature=tf.TensorSpec(
shape=[None] * len(self._input_image_size) +
[self._num_channels],
dtype=tf.uint8),
parallel_iterations=32))
images = tf.stack(images)
return self.serve(images)
def get_inference_signatures(self, function_keys: Dict[Text, Text]):
"""Gets defined function signatures.
Args:
function_keys: A dictionary mapping each function to create a signature
for to its signature key in the returned dictionary.
Returns:
A dictionary mapping each signature key to a concrete function that can
be used with tf.saved_model.save.
"""
signatures = {}
for key, def_name in function_keys.items():
if key == 'image_tensor':
input_signature = tf.TensorSpec(
shape=[self._batch_size] + [None] * len(self._input_image_size) +
[self._num_channels],
dtype=tf.uint8)
signatures[
def_name] = self.inference_from_image_tensors.get_concrete_function(
input_signature)
elif key == 'image_bytes':
input_signature = tf.TensorSpec(
shape=[self._batch_size], dtype=tf.string)
signatures[
def_name] = self.inference_from_image_bytes.get_concrete_function(
input_signature)
elif key == 'serve_examples' or key == 'tf_example':
input_signature = tf.TensorSpec(
shape=[self._batch_size], dtype=tf.string)
signatures[
def_name] = self.inference_from_tf_example.get_concrete_function(
input_signature)
elif key == 'tflite':
input_signature = tf.TensorSpec(
shape=[self._batch_size] + self._input_image_size +
[self._num_channels],
dtype=tf.float32)
signatures[def_name] = self.inference_for_tflite.get_concrete_function(
input_signature)
else:
raise ValueError('Unrecognized `input_type`: {}'.format(key))
return signatures
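# Example usage (a minimal sketch, not part of this module): a hypothetical
# subclass `MyVisionModule` that implements `_build_model` and `serve` could
# be exported with `tf.saved_model.save`; the name, paths, and sizes below are
# placeholders.
#
#   module = MyVisionModule(
#       params=params,
#       batch_size=1,
#       input_image_size=[224, 224],
#       input_type='image_tensor')
#   signatures = module.get_inference_signatures(
#       {'image_tensor': 'serving_default'})
#   tf.saved_model.save(module, '/tmp/export_dir', signatures=signatures)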
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base class for model export."""
from typing import Dict, Optional, Text, Callable, Any, Union
import tensorflow as tf
from official.core import export_base
class ExportModule(export_base.ExportModule):
"""Base Export Module."""
def __init__(self,
params,
model: tf.keras.Model,
input_signature: Union[tf.TensorSpec, Dict[str, tf.TensorSpec]],
preprocessor: Optional[Callable[..., Any]] = None,
inference_step: Optional[Callable[..., Any]] = None,
postprocessor: Optional[Callable[..., Any]] = None):
"""Initializes a module for export.
Args:
params: A dataclass for parameters to the module.
model: A tf.keras.Model instance to be exported.
input_signature: A tf.TensorSpec or a dict of tf.TensorSpec, e.g.
tf.TensorSpec(shape=[None, 224, 224, 3], dtype=tf.uint8).
preprocessor: An optional callable to preprocess the inputs.
inference_step: An optional callable to forward-pass the model.
postprocessor: An optional callable to postprocess the model outputs.
"""
super().__init__(
params,
model=model,
preprocessor=preprocessor,
inference_step=inference_step,
postprocessor=postprocessor)
self.input_signature = input_signature
@tf.function
def serve(self, inputs):
x = self.preprocessor(inputs=inputs) if self.preprocessor else inputs
x = self.inference_step(x)
x = self.postprocessor(x) if self.postprocessor else x
return x
def get_inference_signatures(self, function_keys: Dict[Text, Text]):
"""Gets defined function signatures.
Args:
function_keys: A dictionary mapping each function to create a signature
for to its signature key in the returned dictionary.
Returns:
A dictionary mapping each signature key to a concrete function that can
be used with tf.saved_model.save.
"""
signatures = {}
for def_name in function_keys.values():
signatures[def_name] = self.serve.get_concrete_function(
self.input_signature)
return signatures
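# Example usage (a sketch; assumes `model` is a built tf.keras.Model whose
# call returns a logits tensor, and uses `export_base.export` as imported
# above; the callables and path are illustrative):
#
#   module = ExportModule(
#       params=None,
#       model=model,
#       input_signature=tf.TensorSpec(shape=[None, 224, 224, 3],
#                                     dtype=tf.uint8),
#       preprocessor=lambda inputs: tf.cast(inputs, tf.float32) / 255.0,
#       postprocessor=lambda logits: {'probs': tf.nn.softmax(logits)})
#   export_base.export(module, ['serving_default'],
#                      export_savedmodel_dir='/tmp/export_dir')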
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for official.core.export_base_v2."""
import os
import tensorflow as tf
from official.core import export_base
from official.vision.beta.serving import export_base_v2
class TestModel(tf.keras.Model):
def __init__(self):
super().__init__()
self._dense = tf.keras.layers.Dense(2)
def call(self, inputs):
return {'outputs': self._dense(inputs)}
class ExportBaseTest(tf.test.TestCase):
def test_preprocessor(self):
tmp_dir = self.get_temp_dir()
model = TestModel()
inputs = tf.ones([2, 4], tf.float32)
preprocess_fn = lambda inputs: 2 * inputs
module = export_base_v2.ExportModule(
params=None,
input_signature=tf.TensorSpec(shape=[2, 4]),
model=model,
preprocessor=preprocess_fn)
expected_output = model(preprocess_fn(inputs))
ckpt_path = tf.train.Checkpoint(model=model).save(
os.path.join(tmp_dir, 'ckpt'))
export_dir = export_base.export(
module, ['serving_default'],
export_savedmodel_dir=tmp_dir,
checkpoint_path=ckpt_path,
timestamped=False)
imported = tf.saved_model.load(export_dir)
output = imported.signatures['serving_default'](inputs)
self.assertAllClose(
output['outputs'].numpy(), expected_output['outputs'].numpy())
def test_postprocessor(self):
tmp_dir = self.get_temp_dir()
model = TestModel()
inputs = tf.ones([2, 4], tf.float32)
postprocess_fn = lambda logits: {'outputs': 2 * logits['outputs']}
module = export_base_v2.ExportModule(
params=None,
model=model,
input_signature=tf.TensorSpec(shape=[2, 4]),
postprocessor=postprocess_fn)
expected_output = postprocess_fn(model(inputs))
ckpt_path = tf.train.Checkpoint(model=model).save(
os.path.join(tmp_dir, 'ckpt'))
export_dir = export_base.export(
module, ['serving_default'],
export_savedmodel_dir=tmp_dir,
checkpoint_path=ckpt_path,
timestamped=False)
imported = tf.saved_model.load(export_dir)
output = imported.signatures['serving_default'](inputs)
self.assertAllClose(
output['outputs'].numpy(), expected_output['outputs'].numpy())
if __name__ == '__main__':
tf.test.main()
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Factory for vision export modules."""
from typing import List, Optional
import tensorflow as tf
from official.core import config_definitions as cfg
from official.vision.beta import configs
from official.vision.beta.dataloaders import classification_input
from official.vision.beta.modeling import factory
from official.vision.beta.serving import export_base_v2 as export_base
from official.vision.beta.serving import export_utils
def create_classification_export_module(params: cfg.ExperimentConfig,
input_type: str,
batch_size: int,
input_image_size: List[int],
num_channels: int = 3):
"""Creats classification export module."""
input_signature = export_utils.get_image_input_signatures(
input_type, batch_size, input_image_size, num_channels)
input_specs = tf.keras.layers.InputSpec(
shape=[batch_size] + input_image_size + [num_channels])
model = factory.build_classification_model(
input_specs=input_specs,
model_config=params.task.model,
l2_regularizer=None)
def preprocess_fn(inputs):
image_tensor = export_utils.parse_image(inputs, input_type,
input_image_size, num_channels)
# If input_type is `tflite`, do not apply image preprocessing.
if input_type == 'tflite':
return image_tensor
def preprocess_image_fn(inputs):
return classification_input.Parser.inference_fn(
inputs, input_image_size, num_channels)
images = tf.map_fn(
preprocess_image_fn, elems=image_tensor,
fn_output_signature=tf.TensorSpec(
shape=input_image_size + [num_channels],
dtype=tf.float32))
return images
def postprocess_fn(logits):
probs = tf.nn.softmax(logits)
return {'logits': logits, 'probs': probs}
export_module = export_base.ExportModule(params,
model=model,
input_signature=input_signature,
preprocessor=preprocess_fn,
postprocessor=postprocess_fn)
return export_module
def get_export_module(params: cfg.ExperimentConfig,
input_type: str,
batch_size: Optional[int],
input_image_size: List[int],
num_channels: int = 3) -> export_base.ExportModule:
"""Factory for export modules."""
if isinstance(params.task,
configs.image_classification.ImageClassificationTask):
export_module = create_classification_export_module(
params, input_type, batch_size, input_image_size, num_channels)
else:
raise ValueError('Export module not implemented for {} task.'.format(
type(params.task)))
return export_module
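# Example usage (a minimal sketch; the experiment name and sizes are
# placeholders, and `exp_factory` is not imported by this module):
#
#   from official.core import exp_factory
#   params = exp_factory.get_exp_config('resnet_imagenet')
#   module = get_export_module(
#       params,
#       input_type='image_tensor',
#       batch_size=1,
#       input_image_size=[224, 224])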
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test for vision modules."""
import io
import os
from absl.testing import parameterized
import numpy as np
from PIL import Image
import tensorflow as tf
from official.common import registry_imports # pylint: disable=unused-import
from official.core import exp_factory
from official.core import export_base
from official.vision.beta.dataloaders import classification_input
from official.vision.beta.serving import export_module_factory
class ImageClassificationExportTest(tf.test.TestCase, parameterized.TestCase):
def _get_classification_module(self, input_type, input_image_size):
params = exp_factory.get_exp_config('resnet_imagenet')
params.task.model.backbone.resnet.model_id = 18
module = export_module_factory.create_classification_export_module(
params, input_type, batch_size=1, input_image_size=input_image_size)
return module
def _get_dummy_input(self, input_type):
"""Get dummy input for the given input type."""
if input_type == 'image_tensor':
return tf.zeros((1, 32, 32, 3), dtype=np.uint8)
elif input_type == 'image_bytes':
image = Image.fromarray(np.zeros((32, 32, 3), dtype=np.uint8))
byte_io = io.BytesIO()
image.save(byte_io, 'PNG')
return [byte_io.getvalue()]
elif input_type == 'tf_example':
image_tensor = tf.zeros((32, 32, 3), dtype=tf.uint8)
encoded_jpeg = tf.image.encode_jpeg(tf.constant(image_tensor)).numpy()
example = tf.train.Example(
features=tf.train.Features(
feature={
'image/encoded':
tf.train.Feature(
bytes_list=tf.train.BytesList(value=[encoded_jpeg])),
})).SerializeToString()
return [example]
@parameterized.parameters(
{'input_type': 'image_tensor'},
{'input_type': 'image_bytes'},
{'input_type': 'tf_example'},
)
def test_export(self, input_type='image_tensor'):
input_image_size = [32, 32]
tmp_dir = self.get_temp_dir()
module = self._get_classification_module(input_type, input_image_size)
# Test that the model restores any attrs that are trackable objects
# (eg: tables, resource variables, keras models/layers, tf.hub modules).
module.model.test_trackable = tf.keras.layers.InputLayer(input_shape=(4,))
ckpt_path = tf.train.Checkpoint(model=module.model).save(
os.path.join(tmp_dir, 'ckpt'))
export_dir = export_base.export(
module, [input_type],
export_savedmodel_dir=tmp_dir,
checkpoint_path=ckpt_path,
timestamped=False)
self.assertTrue(os.path.exists(os.path.join(tmp_dir, 'saved_model.pb')))
self.assertTrue(os.path.exists(
os.path.join(tmp_dir, 'variables', 'variables.index')))
self.assertTrue(os.path.exists(
os.path.join(tmp_dir, 'variables', 'variables.data-00000-of-00001')))
imported = tf.saved_model.load(export_dir)
classification_fn = imported.signatures['serving_default']
images = self._get_dummy_input(input_type)
def preprocess_image_fn(inputs):
return classification_input.Parser.inference_fn(
inputs, input_image_size, num_channels=3)
processed_images = tf.map_fn(
preprocess_image_fn,
elems=tf.zeros([1] + input_image_size + [3], dtype=tf.uint8),
fn_output_signature=tf.TensorSpec(
shape=input_image_size + [3], dtype=tf.float32))
expected_logits = module.model(processed_images, training=False)
expected_prob = tf.nn.softmax(expected_logits)
out = classification_fn(tf.constant(images))
# The imported model should contain any trackable attrs that the original
# model had.
self.assertTrue(hasattr(imported.model, 'test_trackable'))
self.assertAllClose(
out['logits'].numpy(), expected_logits.numpy(), rtol=1e-04, atol=1e-04)
self.assertAllClose(
out['probs'].numpy(), expected_prob.numpy(), rtol=1e-04, atol=1e-04)
if __name__ == '__main__':
tf.test.main()
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Vision models export binary for serving/inference.
To export a trained checkpoint in saved_model format (shell script):
EXPERIMENT_TYPE=XX
CHECKPOINT_PATH=XX
EXPORT_DIR_PATH=XX
export_saved_model --experiment=${EXPERIMENT_TYPE} \
--export_dir=${EXPORT_DIR_PATH}/ \
--checkpoint_path=${CHECKPOINT_PATH} \
--batch_size=2 \
--input_image_size=224,224
To serve (python):
export_dir_path = XX
input_type = XX
input_images = XX
imported = tf.saved_model.load(export_dir_path)
model_fn = imported.signatures['serving_default']
output = model_fn(input_images)
"""
from absl import app
from absl import flags
from official.common import registry_imports # pylint: disable=unused-import
from official.core import exp_factory
from official.modeling import hyperparams
from official.vision.beta.serving import export_saved_model_lib
FLAGS = flags.FLAGS
flags.DEFINE_string(
'experiment', None, 'experiment type, e.g. retinanet_resnetfpn_coco')
flags.DEFINE_string('export_dir', None, 'The export directory.')
flags.DEFINE_string('checkpoint_path', None, 'Checkpoint path.')
flags.DEFINE_multi_string(
'config_file',
default=None,
help='YAML/JSON files which specify overrides. The override order '
'follows the order of args. Note that each file '
'can be used as an override template to override the default parameters '
'specified in Python. If the same parameter is specified in both '
'`--config_file` and `--params_override`, `config_file` will be used '
'first, followed by params_override.')
flags.DEFINE_string(
'params_override', '',
'The JSON/YAML file or string which specifies the parameters to be overridden'
' on top of `config_file` template.')
flags.DEFINE_integer(
'batch_size', None, 'The batch size.')
flags.DEFINE_string(
'input_type', 'image_tensor',
'One of `image_tensor`, `image_bytes`, `tf_example` and `tflite`.')
flags.DEFINE_string(
'input_image_size', '224,224',
'The comma-separated string of two integers representing the height,width '
'of the input to the model.')
flags.DEFINE_string('export_checkpoint_subdir', 'checkpoint',
'The subdirectory for checkpoints.')
flags.DEFINE_string('export_saved_model_subdir', 'saved_model',
'The subdirectory for saved model.')
def main(_):
params = exp_factory.get_exp_config(FLAGS.experiment)
for config_file in FLAGS.config_file or []:
params = hyperparams.override_params_dict(
params, config_file, is_strict=True)
if FLAGS.params_override:
params = hyperparams.override_params_dict(
params, FLAGS.params_override, is_strict=True)
params.validate()
params.lock()
export_saved_model_lib.export_inference_graph(
input_type=FLAGS.input_type,
batch_size=FLAGS.batch_size,
input_image_size=[int(x) for x in FLAGS.input_image_size.split(',')],
params=params,
checkpoint_path=FLAGS.checkpoint_path,
export_dir=FLAGS.export_dir,
export_checkpoint_subdir=FLAGS.export_checkpoint_subdir,
export_saved_model_subdir=FLAGS.export_saved_model_subdir)
if __name__ == '__main__':
app.run(main)
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Vision models export utility function for serving/inference."""
import os
from typing import Optional, List
from absl import logging
import tensorflow as tf
from official.core import config_definitions as cfg
from official.core import export_base
from official.core import train_utils
from official.vision.beta import configs
from official.vision.beta.serving import detection
from official.vision.beta.serving import image_classification
from official.vision.beta.serving import semantic_segmentation
from official.vision.beta.serving import video_classification
def export_inference_graph(
input_type: str,
batch_size: Optional[int],
input_image_size: List[int],
params: cfg.ExperimentConfig,
checkpoint_path: str,
export_dir: str,
num_channels: Optional[int] = 3,
export_module: Optional[export_base.ExportModule] = None,
export_checkpoint_subdir: Optional[str] = None,
export_saved_model_subdir: Optional[str] = None,
save_options: Optional[tf.saved_model.SaveOptions] = None,
log_model_flops_and_params: bool = False):
"""Exports inference graph for the model specified in the exp config.
Saved model is stored at export_dir/saved_model, checkpoint is saved
at export_dir/checkpoint, and params is saved at export_dir/params.yaml.
Args:
input_type: One of `image_tensor`, `image_bytes`, `tf_example` or `tflite`.
batch_size: An `int`, or None.
input_image_size: List or Tuple of height and width.
params: Experiment params.
checkpoint_path: Trained checkpoint path or directory.
export_dir: Export directory path.
num_channels: The number of input image channels.
export_module: Optional export module to be used instead of using params
to create one. If None, the params will be used to create an export
module.
export_checkpoint_subdir: Optional subdirectory under export_dir
to store checkpoint.
export_saved_model_subdir: Optional subdirectory under export_dir
to store saved model.
save_options: `SaveOptions` for `tf.saved_model.save`.
log_model_flops_and_params: If True, writes model FLOPs to model_flops.txt
and model parameters to model_params.txt.
"""
if export_checkpoint_subdir:
output_checkpoint_directory = os.path.join(
export_dir, export_checkpoint_subdir)
else:
output_checkpoint_directory = None
if export_saved_model_subdir:
output_saved_model_directory = os.path.join(
export_dir, export_saved_model_subdir)
else:
output_saved_model_directory = export_dir
# TODO(arashwan): Offer a direct path to use ExportModule with Task objects.
if not export_module:
if isinstance(params.task,
configs.image_classification.ImageClassificationTask):
export_module = image_classification.ClassificationModule(
params=params,
batch_size=batch_size,
input_image_size=input_image_size,
input_type=input_type,
num_channels=num_channels)
elif isinstance(params.task, configs.retinanet.RetinaNetTask) or isinstance(
params.task, configs.maskrcnn.MaskRCNNTask):
export_module = detection.DetectionModule(
params=params,
batch_size=batch_size,
input_image_size=input_image_size,
input_type=input_type,
num_channels=num_channels)
elif isinstance(params.task,
configs.semantic_segmentation.SemanticSegmentationTask):
export_module = semantic_segmentation.SegmentationModule(
params=params,
batch_size=batch_size,
input_image_size=input_image_size,
input_type=input_type,
num_channels=num_channels)
elif isinstance(params.task,
configs.video_classification.VideoClassificationTask):
export_module = video_classification.VideoClassificationModule(
params=params,
batch_size=batch_size,
input_image_size=input_image_size,
input_type=input_type,
num_channels=num_channels)
else:
raise ValueError('Export module not implemented for {} task.'.format(
type(params.task)))
export_base.export(
export_module,
function_keys=[input_type],
export_savedmodel_dir=output_saved_model_directory,
checkpoint_path=checkpoint_path,
timestamped=False,
save_options=save_options)
if output_checkpoint_directory:
ckpt = tf.train.Checkpoint(model=export_module.model)
ckpt.save(os.path.join(output_checkpoint_directory, 'ckpt'))
train_utils.serialize_config(params, export_dir)
if log_model_flops_and_params:
inputs_kwargs = None
if isinstance(
params.task,
(configs.retinanet.RetinaNetTask, configs.maskrcnn.MaskRCNNTask)):
# We need to create the inputs_kwargs argument to specify the input shapes
# for subclass models that override model.call to take multiple inputs,
# e.g., the RetinaNet model.
inputs_kwargs = {
'images':
tf.TensorSpec([1] + input_image_size + [num_channels],
tf.float32),
'image_shape':
tf.TensorSpec([1, 2], tf.float32)
}
dummy_inputs = {
k: tf.ones(v.shape.as_list(), tf.float32)
for k, v in inputs_kwargs.items()
}
# Must do forward pass to build the model.
export_module.model(**dummy_inputs)
else:
logging.info(
'Logging model flops and params not implemented for %s task.',
type(params.task))
return
train_utils.try_count_flops(export_module.model, inputs_kwargs,
os.path.join(export_dir, 'model_flops.txt'))
train_utils.write_model_params(export_module.model,
os.path.join(export_dir, 'model_params.txt'))
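# Example usage (a sketch; the experiment name and paths are placeholders, and
# `exp_factory` is not imported by this module):
#
#   from official.core import exp_factory
#   params = exp_factory.get_exp_config('resnet_imagenet')
#   export_inference_graph(
#       input_type='image_tensor',
#       batch_size=1,
#       input_image_size=[224, 224],
#       params=params,
#       checkpoint_path='/tmp/ckpt-1',
#       export_dir='/tmp/export_dir')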
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for official.core.export_saved_model_lib."""
import os
from unittest import mock
import tensorflow as tf
from official.core import export_base
from official.vision.beta import configs
from official.vision.beta.serving import export_saved_model_lib
class WriteModelFlopsAndParamsTest(tf.test.TestCase):
def setUp(self):
super().setUp()
self.tempdir = self.create_tempdir()
self.enter_context(
mock.patch.object(export_base, 'export', autospec=True, spec_set=True))
def _export_model_with_log_model_flops_and_params(self, params):
export_saved_model_lib.export_inference_graph(
input_type='image_tensor',
batch_size=1,
input_image_size=[64, 64],
params=params,
checkpoint_path=os.path.join(self.tempdir, 'unused-ckpt'),
export_dir=self.tempdir,
log_model_flops_and_params=True)
def assertModelAnalysisFilesExist(self):
self.assertTrue(
tf.io.gfile.exists(os.path.join(self.tempdir, 'model_params.txt')))
self.assertTrue(
tf.io.gfile.exists(os.path.join(self.tempdir, 'model_flops.txt')))
def test_retinanet_task(self):
params = configs.retinanet.retinanet_resnetfpn_coco()
params.task.model.backbone.resnet.model_id = 18
params.task.model.num_classes = 2
params.task.model.max_level = 6
self._export_model_with_log_model_flops_and_params(params)
self.assertModelAnalysisFilesExist()
def test_maskrcnn_task(self):
params = configs.maskrcnn.maskrcnn_resnetfpn_coco()
params.task.model.backbone.resnet.model_id = 18
params.task.model.num_classes = 2
params.task.model.max_level = 6
self._export_model_with_log_model_flops_and_params(params)
self.assertModelAnalysisFilesExist()
if __name__ == '__main__':
tf.test.main()
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Vision models export utility function for serving/inference."""
import os
from typing import Optional, List
import tensorflow as tf
from official.core import config_definitions as cfg
from official.core import export_base
from official.core import train_utils
from official.vision.beta.serving import export_module_factory
def export(
input_type: str,
batch_size: Optional[int],
input_image_size: List[int],
params: cfg.ExperimentConfig,
checkpoint_path: str,
export_dir: str,
num_channels: Optional[int] = 3,
export_module: Optional[export_base.ExportModule] = None,
export_checkpoint_subdir: Optional[str] = None,
export_saved_model_subdir: Optional[str] = None,
save_options: Optional[tf.saved_model.SaveOptions] = None):
"""Exports the model specified in the exp config.
Saved model is stored at export_dir/saved_model, checkpoint is saved
at export_dir/checkpoint, and params is saved at export_dir/params.yaml.
Args:
input_type: One of `image_tensor`, `image_bytes`, `tf_example`.
batch_size: An `int`, or None.
input_image_size: List or Tuple of height and width.
params: Experiment params.
checkpoint_path: Trained checkpoint path or directory.
export_dir: Export directory path.
num_channels: The number of input image channels.
export_module: Optional export module to be used instead of using params
to create one. If None, the params will be used to create an export
module.
export_checkpoint_subdir: Optional subdirectory under export_dir
to store checkpoint.
export_saved_model_subdir: Optional subdirectory under export_dir
to store saved model.
save_options: `SaveOptions` for `tf.saved_model.save`.
"""
if export_checkpoint_subdir:
output_checkpoint_directory = os.path.join(
export_dir, export_checkpoint_subdir)
else:
output_checkpoint_directory = None
if export_saved_model_subdir:
output_saved_model_directory = os.path.join(
export_dir, export_saved_model_subdir)
else:
output_saved_model_directory = export_dir
if export_module is None:
export_module = export_module_factory.get_export_module(
params,
input_type=input_type,
batch_size=batch_size,
input_image_size=input_image_size,
num_channels=num_channels)
export_base.export(
export_module,
function_keys=[input_type],
export_savedmodel_dir=output_saved_model_directory,
checkpoint_path=checkpoint_path,
timestamped=False,
save_options=save_options)
if output_checkpoint_directory:
ckpt = tf.train.Checkpoint(model=export_module.model)
ckpt.save(os.path.join(output_checkpoint_directory, 'ckpt'))
train_utils.serialize_config(params, export_dir)
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A script to export the image classification as a TF-Hub SavedModel."""
# Import libraries
from absl import app
from absl import flags
import tensorflow as tf
from official.common import registry_imports # pylint: disable=unused-import
from official.core import exp_factory
from official.modeling import hyperparams
from official.vision.beta.modeling import factory
FLAGS = flags.FLAGS
flags.DEFINE_string(
'experiment', None, 'experiment type, e.g. resnet_imagenet')
flags.DEFINE_string(
'checkpoint_path', None, 'Checkpoint path.')
flags.DEFINE_string(
'export_path', None, 'The export directory.')
flags.DEFINE_multi_string(
'config_file',
None,
'YAML/JSON files which specify overrides. The override order '
'follows the order of args. Note that each file '
'can be used as an override template to override the default parameters '
'specified in Python. If the same parameter is specified in both '
'`--config_file` and `--params_override`, `config_file` will be used '
'first, followed by params_override.')
flags.DEFINE_string(
'params_override', '',
'The JSON/YAML file or string which specifies the parameters to be overridden'
' on top of `config_file` template.')
flags.DEFINE_integer(
'batch_size', None, 'The batch size.')
flags.DEFINE_string(
'input_image_size',
'224,224',
'The comma-separated string of two integers representing the height,width '
'of the input to the model.')
flags.DEFINE_boolean(
'skip_logits_layer',
False,
'Whether to skip the prediction layer and only output the feature vector.')
def export_model_to_tfhub(params,
batch_size,
input_image_size,
skip_logits_layer,
checkpoint_path,
export_path):
"""Export an image classification model to TF-Hub."""
input_specs = tf.keras.layers.InputSpec(shape=[batch_size] +
input_image_size + [3])
model = factory.build_classification_model(
input_specs=input_specs,
model_config=params.task.model,
l2_regularizer=None,
skip_logits_layer=skip_logits_layer)
checkpoint = tf.train.Checkpoint(model=model)
checkpoint.restore(checkpoint_path).assert_existing_objects_matched()
model.save(export_path, include_optimizer=False, save_format='tf')
def main(_):
params = exp_factory.get_exp_config(FLAGS.experiment)
for config_file in FLAGS.config_file or []:
params = hyperparams.override_params_dict(
params, config_file, is_strict=True)
if FLAGS.params_override:
params = hyperparams.override_params_dict(
params, FLAGS.params_override, is_strict=True)
params.validate()
params.lock()
export_model_to_tfhub(
params=params,
batch_size=FLAGS.batch_size,
input_image_size=[int(x) for x in FLAGS.input_image_size.split(',')],
skip_logits_layer=FLAGS.skip_logits_layer,
checkpoint_path=FLAGS.checkpoint_path,
export_path=FLAGS.export_path)
if __name__ == '__main__':
app.run(main)
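# Example invocation (a sketch; all flag values are placeholders):
#
#   export_tfhub --experiment=resnet_imagenet \
#     --checkpoint_path=/tmp/ckpt-1 \
#     --export_path=/tmp/tfhub_module \
#     --batch_size=1 \
#     --input_image_size=224,224 \
#     --skip_logits_layer=true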
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Binary to convert a saved model to tflite model.
It requires a SavedModel exported by export_saved_model.py with batch size 1
and input type `tflite`, together with the same config file that was used to
export the SavedModel. It supports optional post-training quantization. When
integer quantization is used, calibration steps must be provided to calibrate
the model input.
To convert a SavedModel to a TFLite model:
EXPERIMENT_TYPE=XX
TFLITE_PATH=XX
SAVED_MODEL_DIR=XX
CONFIG_FILE=XX
export_tflite --experiment=${EXPERIMENT_TYPE} \
--saved_model_dir=${SAVED_MODEL_DIR} \
--tflite_path=${TFLITE_PATH} \
--config_file=${CONFIG_FILE} \
--quant_type=fp16 \
--calibration_steps=500
"""
from absl import app
from absl import flags
from absl import logging
import tensorflow as tf
from official.common import registry_imports # pylint: disable=unused-import
from official.core import exp_factory
from official.modeling import hyperparams
from official.vision.beta.serving import export_tflite_lib
FLAGS = flags.FLAGS
flags.DEFINE_string(
'experiment',
None,
'experiment type, e.g. retinanet_resnetfpn_coco',
required=True)
flags.DEFINE_multi_string(
'config_file',
default=None,
help='YAML/JSON files which specify overrides. The override order '
'follows the order of args. Note that each file '
'can be used as an override template to override the default parameters '
'specified in Python. If the same parameter is specified in both '
'`--config_file` and `--params_override`, `config_file` will be used '
'first, followed by params_override.')
flags.DEFINE_string(
'params_override', '',
'The JSON/YAML file or string which specifies the parameters to be overridden'
' on top of `config_file` template.')
flags.DEFINE_string(
'saved_model_dir', None, 'The directory to the saved model.', required=True)
flags.DEFINE_string(
'tflite_path', None, 'The path to the output tflite model.', required=True)
flags.DEFINE_string(
'quant_type',
default=None,
help='Post training quantization type. Support `int8`, `int8_full`, '
'`fp16`, and `default`. See '
'https://www.tensorflow.org/lite/performance/post_training_quantization '
'for more details.')
flags.DEFINE_integer('calibration_steps', 500,
'The number of calibration steps for integer model.')
def main(_) -> None:
params = exp_factory.get_exp_config(FLAGS.experiment)
for config_file in FLAGS.config_file or []:
params = hyperparams.override_params_dict(
params, config_file, is_strict=True)
if FLAGS.params_override:
params = hyperparams.override_params_dict(
params, FLAGS.params_override, is_strict=True)
params.validate()
params.lock()
logging.info('Converting SavedModel from %s to TFLite model...',
FLAGS.saved_model_dir)
tflite_model = export_tflite_lib.convert_tflite_model(
saved_model_dir=FLAGS.saved_model_dir,
quant_type=FLAGS.quant_type,
params=params,
calibration_steps=FLAGS.calibration_steps)
with tf.io.gfile.GFile(FLAGS.tflite_path, 'wb') as fw:
fw.write(tflite_model)
logging.info('TFLite model converted and saved to %s.', FLAGS.tflite_path)
if __name__ == '__main__':
app.run(main)
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Library to facilitate TFLite model conversion."""
import functools
from typing import Iterator, List, Optional
from absl import logging
import tensorflow as tf
from official.core import config_definitions as cfg
from official.vision.beta import configs
from official.vision.beta import tasks
def create_representative_dataset(
params: cfg.ExperimentConfig) -> tf.data.Dataset:
"""Creates a tf.data.Dataset to load images for representative dataset.
Args:
params: An ExperimentConfig.
Returns:
A tf.data.Dataset instance.
Raises:
ValueError: If task is not supported.
"""
if isinstance(params.task,
configs.image_classification.ImageClassificationTask):
task = tasks.image_classification.ImageClassificationTask(params.task)
elif isinstance(params.task, configs.retinanet.RetinaNetTask):
task = tasks.retinanet.RetinaNetTask(params.task)
elif isinstance(params.task, configs.maskrcnn.MaskRCNNTask):
task = tasks.maskrcnn.MaskRCNNTask(params.task)
elif isinstance(params.task,
configs.semantic_segmentation.SemanticSegmentationTask):
task = tasks.semantic_segmentation.SemanticSegmentationTask(params.task)
else:
raise ValueError('Task {} not supported.'.format(type(params.task)))
# Ensure batch size is 1 for TFLite model.
params.task.train_data.global_batch_size = 1
params.task.train_data.dtype = 'float32'
logging.info('Task config: %s', params.task.as_dict())
return task.build_inputs(params=params.task.train_data)
def representative_dataset(
params: cfg.ExperimentConfig,
calibration_steps: int = 2000) -> Iterator[List[tf.Tensor]]:
""""Creates representative dataset for input calibration.
Args:
params: An ExperimentConfig.
calibration_steps: The steps to do calibration.
Yields:
An input image tensor.
"""
dataset = create_representative_dataset(params=params)
for image, _ in dataset.take(calibration_steps):
# Skip images that do not have 3 channels.
if image.shape[-1] != 3:
continue
yield [image]
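# Example usage (a sketch; assumes `params` is a valid ExperimentConfig for
# one of the supported tasks with reachable training data):
#
#   for sample in representative_dataset(params, calibration_steps=3):
#     image = sample[0]  # A single [1, height, width, 3] float32 batch.
#     print(image.shape)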
def convert_tflite_model(saved_model_dir: str,
quant_type: Optional[str] = None,
params: Optional[cfg.ExperimentConfig] = None,
calibration_steps: Optional[int] = 2000) -> bytes:
"""Converts and returns a TFLite model.
Args:
saved_model_dir: The directory to the SavedModel.
quant_type: The post training quantization (PTQ) method. It can be one of
`default` (dynamic range), `fp16` (float16), `int8` (integer with float
fallback), `int8_full` (integer only) and None (no quantization).
params: An optional ExperimentConfig to load and preprocess input images to
do calibration for integer quantization.
calibration_steps: The steps to do calibration.
Returns:
A converted TFLite model with optional PTQ.
Raises:
ValueError: If `params` specifies an unsupported task when building the
representative dataset for integer quantization.
"""
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
if quant_type:
if quant_type.startswith('int8'):
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = functools.partial(
representative_dataset,
params=params,
calibration_steps=calibration_steps)
if quant_type == 'int8_full':
converter.target_spec.supported_ops = [
tf.lite.OpsSet.TFLITE_BUILTINS_INT8
]
converter.inference_input_type = tf.uint8 # or tf.int8
converter.inference_output_type = tf.uint8 # or tf.int8
elif quant_type == 'fp16':
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_types = [tf.float16]
elif quant_type == 'default':
converter.optimizations = [tf.lite.Optimize.DEFAULT]
return converter.convert()
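# Example usage (a sketch; the paths are placeholders, and `params` is only
# needed for the `int8`/`int8_full` calibration modes):
#
#   tflite_model = convert_tflite_model(
#       saved_model_dir='/tmp/saved_model', quant_type='fp16')
#   with tf.io.gfile.GFile('/tmp/model.tflite', 'wb') as f:
#     f.write(tflite_model)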
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for export_tflite_lib."""
import os
from absl.testing import parameterized
import tensorflow as tf
from tensorflow.python.distribute import combinations
from official.core import exp_factory
from official.vision.beta import configs # pylint: disable=unused-import
from official.vision.beta.dataloaders import tfexample_utils
from official.vision.beta.serving import detection as detection_serving
from official.vision.beta.serving import export_tflite_lib
from official.vision.beta.serving import image_classification as image_classification_serving
from official.vision.beta.serving import semantic_segmentation as semantic_segmentation_serving
class ExportTfliteLibTest(tf.test.TestCase, parameterized.TestCase):
def _create_test_tfrecord(self, tfrecord_file, example, num_samples):
examples = [example] * num_samples
tfexample_utils.dump_to_tfrecord(
record_file=tfrecord_file, tf_examples=examples)
def _export_from_module(self, module, input_type, saved_model_dir):
signatures = module.get_inference_signatures(
{input_type: 'serving_default'})
tf.saved_model.save(module, saved_model_dir, signatures=signatures)
@combinations.generate(
combinations.combine(
experiment=['mobilenet_imagenet'],
quant_type=[None, 'default', 'fp16', 'int8', 'int8_full'],
input_image_size=[[224, 224]]))
def test_export_tflite_image_classification(self, experiment, quant_type,
input_image_size):
test_tfrecord_file = os.path.join(self.get_temp_dir(), 'cls_test.tfrecord')
example = tf.train.Example.FromString(
tfexample_utils.create_classification_example(
image_height=input_image_size[0], image_width=input_image_size[1]))
self._create_test_tfrecord(
tfrecord_file=test_tfrecord_file, example=example, num_samples=10)
params = exp_factory.get_exp_config(experiment)
params.task.validation_data.input_path = test_tfrecord_file
params.task.train_data.input_path = test_tfrecord_file
temp_dir = self.get_temp_dir()
module = image_classification_serving.ClassificationModule(
params=params,
batch_size=1,
input_image_size=input_image_size,
input_type='tflite')
self._export_from_module(
module=module,
input_type='tflite',
saved_model_dir=os.path.join(temp_dir, 'saved_model'))
tflite_model = export_tflite_lib.convert_tflite_model(
saved_model_dir=os.path.join(temp_dir, 'saved_model'),
quant_type=quant_type,
params=params,
calibration_steps=5)
self.assertIsInstance(tflite_model, bytes)
@combinations.generate(
combinations.combine(
experiment=['retinanet_mobile_coco'],
quant_type=[None, 'default', 'fp16'],
input_image_size=[[384, 384]]))
def test_export_tflite_detection(self, experiment, quant_type,
input_image_size):
test_tfrecord_file = os.path.join(self.get_temp_dir(), 'det_test.tfrecord')
example = tfexample_utils.create_detection_test_example(
image_height=input_image_size[0],
image_width=input_image_size[1],
image_channel=3,
num_instances=10)
self._create_test_tfrecord(
tfrecord_file=test_tfrecord_file, example=example, num_samples=10)
params = exp_factory.get_exp_config(experiment)
params.task.validation_data.input_path = test_tfrecord_file
params.task.train_data.input_path = test_tfrecord_file
temp_dir = self.get_temp_dir()
module = detection_serving.DetectionModule(
params=params,
batch_size=1,
input_image_size=input_image_size,
input_type='tflite')
self._export_from_module(
module=module,
input_type='tflite',
saved_model_dir=os.path.join(temp_dir, 'saved_model'))
tflite_model = export_tflite_lib.convert_tflite_model(
saved_model_dir=os.path.join(temp_dir, 'saved_model'),
quant_type=quant_type,
params=params,
calibration_steps=5)
self.assertIsInstance(tflite_model, bytes)
@combinations.generate(
combinations.combine(
experiment=['mnv2_deeplabv3_pascal'],
quant_type=[None, 'default', 'fp16', 'int8', 'int8_full'],
input_image_size=[[512, 512]]))
def test_export_tflite_semantic_segmentation(self, experiment, quant_type,
input_image_size):
test_tfrecord_file = os.path.join(self.get_temp_dir(), 'seg_test.tfrecord')
example = tfexample_utils.create_segmentation_test_example(
image_height=input_image_size[0],
image_width=input_image_size[1],
image_channel=3)
self._create_test_tfrecord(
tfrecord_file=test_tfrecord_file, example=example, num_samples=10)
params = exp_factory.get_exp_config(experiment)
params.task.validation_data.input_path = test_tfrecord_file
params.task.train_data.input_path = test_tfrecord_file
temp_dir = self.get_temp_dir()
module = semantic_segmentation_serving.SegmentationModule(
params=params,
batch_size=1,
input_image_size=input_image_size,
input_type='tflite')
self._export_from_module(
module=module,
input_type='tflite',
saved_model_dir=os.path.join(temp_dir, 'saved_model'))
tflite_model = export_tflite_lib.convert_tflite_model(
saved_model_dir=os.path.join(temp_dir, 'saved_model'),
quant_type=quant_type,
params=params,
calibration_steps=5)
self.assertIsInstance(tflite_model, bytes)
if __name__ == '__main__':
tf.test.main()