Commit c8e6faf7 authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 431756117
parent 13a5e4fb
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Spatial transform ops."""
import tensorflow as tf
_EPSILON = 1e-8
def _feature_bilinear_interpolation(features, kernel_y, kernel_x):
"""Feature bilinear interpolation.
The RoIAlign feature f can be computed by bilinear interpolation
of four neighboring feature points f0, f1, f2, and f3.
f(y, x) = [hy, ly] * [[f00, f01], * [hx, lx]^T
[f10, f11]]
f(y, x) = (hy*hx)f00 + (hy*lx)f01 + (ly*hx)f10 + (lx*ly)f11
f(y, x) = w00*f00 + w01*f01 + w10*f10 + w11*f11
kernel_y = [hy, ly]
kernel_x = [hx, lx]
Args:
features: The features are in shape of [batch_size, num_boxes, output_size *
2, output_size * 2, num_filters].
kernel_y: Tensor of size [batch_size, boxes, output_size, 2, 1].
kernel_x: Tensor of size [batch_size, boxes, output_size, 2, 1].
Returns:
A 5-D tensor representing feature crop of shape
[batch_size, num_boxes, output_size, output_size, num_filters].
"""
features_shape = tf.shape(features)
batch_size, num_boxes, output_size, num_filters = (
features_shape[0], features_shape[1], features_shape[2],
features_shape[4])
output_size = output_size // 2
kernel_y = tf.reshape(kernel_y, [batch_size, num_boxes, output_size * 2, 1])
kernel_x = tf.reshape(kernel_x, [batch_size, num_boxes, 1, output_size * 2])
# Use implicit broadcast to generate the interpolation kernel. The
# multiplier `4` is for avg pooling.
interpolation_kernel = kernel_y * kernel_x * 4
# Interpolate the gathered features with computed interpolation kernels.
features *= tf.cast(
tf.expand_dims(interpolation_kernel, axis=-1), dtype=features.dtype)
features = tf.reshape(
features,
[batch_size * num_boxes, output_size * 2, output_size * 2, num_filters])
features = tf.nn.avg_pool(features, [1, 2, 2, 1], [1, 2, 2, 1], 'VALID')
features = tf.reshape(
features, [batch_size, num_boxes, output_size, output_size, num_filters])
return features
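# A minimal sketch of the formula above for a single box and a single output
# pixel (output_size = 1): the 2x2 gathered corners below play the role of
# f00..f11 and the kernels encode [hy, ly] = [0.6, 0.4], [hx, lx] = [0.7, 0.3],
# so the pooled result equals w00*f00 + w01*f01 + w10*f10 + w11*f11. All tensor
# values are illustrative only.
_interp_corners = tf.constant(
    [[[[[1.], [2.]],
       [[3.], [4.]]]]])                                # [1, 1, 2, 2, 1]
_interp_kernel_y = tf.constant([[[[[0.6], [0.4]]]]])   # [1, 1, 1, 2, 1]
_interp_kernel_x = tf.constant([[[[[0.7], [0.3]]]]])   # [1, 1, 1, 2, 1]
_interp_example = _feature_bilinear_interpolation(
    _interp_corners, _interp_kernel_y, _interp_kernel_x)
# _interp_example -> 0.42*1 + 0.18*2 + 0.28*3 + 0.12*4 = 2.1, shape [1, 1, 1, 1, 1].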
def _compute_grid_positions(boxes, boundaries, output_size, sample_offset):
"""Computes the grid position w.r.t. the corresponding feature map.
Args:
boxes: a 3-D tensor of shape [batch_size, num_boxes, 4] encoding the
information of each box w.r.t. the corresponding feature map.
boxes[:, :, 0:2] are the grid position in (y, x) (float) of the top-left
corner of each box. boxes[:, :, 2:4] are the box sizes in (h, w) (float)
in terms of the number of pixels of the corresponding feature map size.
boundaries: a 3-D tensor of shape [batch_size, num_boxes, 2] representing
the boundary (in (y, x)) of the corresponding feature map for each box.
Any resampled grid points that go beyond the boundary will be clipped.
output_size: a scalar indicating the output crop size.
sample_offset: a float in [0, 1] indicating the subpixel sample offset from
the grid point.
Returns:
kernel_y: Tensor of size [batch_size, boxes, output_size, 2, 1].
kernel_x: Tensor of size [batch_size, boxes, output_size, 2, 1].
box_grid_y0y1: Tensor of size [batch_size, boxes, output_size, 2]
box_grid_x0x1: Tensor of size [batch_size, boxes, output_size, 2]
"""
boxes_shape = tf.shape(boxes)
batch_size, num_boxes = boxes_shape[0], boxes_shape[1]
if batch_size is None:
batch_size = tf.shape(boxes)[0]
box_grid_x = []
box_grid_y = []
for i in range(output_size):
box_grid_x.append(boxes[:, :, 1] +
(i + sample_offset) * boxes[:, :, 3] / output_size)
box_grid_y.append(boxes[:, :, 0] +
(i + sample_offset) * boxes[:, :, 2] / output_size)
box_grid_x = tf.stack(box_grid_x, axis=2)
box_grid_y = tf.stack(box_grid_y, axis=2)
box_grid_y0 = tf.floor(box_grid_y)
box_grid_x0 = tf.floor(box_grid_x)
box_grid_x0 = tf.maximum(tf.cast(0., dtype=box_grid_x0.dtype), box_grid_x0)
box_grid_y0 = tf.maximum(tf.cast(0., dtype=box_grid_y0.dtype), box_grid_y0)
box_grid_x0 = tf.minimum(box_grid_x0, tf.expand_dims(boundaries[:, :, 1], -1))
box_grid_x1 = tf.minimum(box_grid_x0 + 1,
tf.expand_dims(boundaries[:, :, 1], -1))
box_grid_y0 = tf.minimum(box_grid_y0, tf.expand_dims(boundaries[:, :, 0], -1))
box_grid_y1 = tf.minimum(box_grid_y0 + 1,
tf.expand_dims(boundaries[:, :, 0], -1))
box_gridx0x1 = tf.stack([box_grid_x0, box_grid_x1], axis=-1)
box_gridy0y1 = tf.stack([box_grid_y0, box_grid_y1], axis=-1)
# The RoIAlign feature f can be computed by bilinear interpolation of four
# neighboring feature points f0, f1, f2, and f3.
# f(y, x) = [hy, ly] * [[f00, f01], * [hx, lx]^T
# [f10, f11]]
# f(y, x) = (hy*hx)f00 + (hy*lx)f01 + (ly*hx)f10 + (lx*ly)f11
# f(y, x) = w00*f00 + w01*f01 + w10*f10 + w11*f11
ly = box_grid_y - box_grid_y0
lx = box_grid_x - box_grid_x0
hy = 1.0 - ly
hx = 1.0 - lx
kernel_y = tf.reshape(
tf.stack([hy, ly], axis=3), [batch_size, num_boxes, output_size, 2, 1])
kernel_x = tf.reshape(
tf.stack([hx, lx], axis=3), [batch_size, num_boxes, output_size, 2, 1])
return kernel_y, kernel_x, box_gridy0y1, box_gridx0x1
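# A small sketch of the grid computation above, assuming one box whose
# top-left corner is (0, 0) and whose size is (4, 4) pixels on a feature map
# bounded at (7, 7). With output_size=2 and sample_offset=0.5 the sample
# centers fall at 1.0 and 3.0 on each axis, so the samples land exactly on
# grid points and the low/high kernels are [1.0, 0.0]; values are
# illustrative only.
_grid_boxes = tf.constant([[[0., 0., 4., 4.]]])        # [1, 1, 4]: (y, x, h, w)
_grid_boundaries = tf.constant([[[7., 7.]]])           # [1, 1, 2]
_grid_ky, _grid_kx, _grid_y0y1, _grid_x0x1 = _compute_grid_positions(
    _grid_boxes, _grid_boundaries, output_size=2, sample_offset=0.5)
# _grid_y0y1 -> [[[[1., 2.], [3., 4.]]]]; _grid_ky holds [hy, ly] = [1., 0.]
# for every sample.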
def multilevel_crop_and_resize(features,
boxes,
output_size=7,
sample_offset=0.5):
"""Crop and resize on multilevel feature pyramid.
Generates the (output_size, output_size) set of pixels for each input box
by first assigning the box to the correct feature level, and then cropping
and resizing it using the corresponding feature map of that level.
Args:
features: A dictionary with key as pyramid level and value as features. The
features are in shape of [batch_size, height_l, width_l, num_filters].
boxes: A 3-D Tensor of shape [batch_size, num_boxes, 4]. Each row represents
a box with [y1, x1, y2, x2] in un-normalized coordinates.
output_size: A scalar to indicate the output crop size.
sample_offset: a float in [0, 1] indicating the subpixel sample offset from
the grid point.
Returns:
A 5-D tensor representing feature crop of shape
[batch_size, num_boxes, output_size, output_size, num_filters].
"""
with tf.name_scope('multilevel_crop_and_resize'):
levels = list(features.keys())
min_level = int(min(levels))
max_level = int(max(levels))
features_shape = tf.shape(features[str(min_level)])
batch_size, max_feature_height, max_feature_width, num_filters = (
features_shape[0], features_shape[1], features_shape[2],
features_shape[3])
num_boxes = tf.shape(boxes)[1]
# Stack feature pyramid into a features_all of shape
# [batch_size, levels, height, width, num_filters].
features_all = []
feature_heights = []
feature_widths = []
for level in range(min_level, max_level + 1):
shape = features[str(level)].get_shape().as_list()
feature_heights.append(shape[1])
feature_widths.append(shape[2])
# Concat tensor of [batch_size, height_l * width_l, num_filters] for each
# level.
features_all.append(
tf.reshape(features[str(level)], [batch_size, -1, num_filters]))
features_r2 = tf.reshape(tf.concat(features_all, 1), [-1, num_filters])
# Calculate height_l * width_l for each level.
level_dim_sizes = [
feature_widths[i] * feature_heights[i]
for i in range(len(feature_widths))
]
# level_dim_offsets is the cumulative sum of level_dim_sizes.
level_dim_offsets = [0]
for i in range(len(feature_widths) - 1):
level_dim_offsets.append(level_dim_offsets[i] + level_dim_sizes[i])
batch_dim_size = level_dim_offsets[-1] + level_dim_sizes[-1]
level_dim_offsets = tf.constant(level_dim_offsets, tf.int32)
height_dim_sizes = tf.constant(feature_widths, tf.int32)
# Assigns boxes to the right level.
box_width = boxes[:, :, 3] - boxes[:, :, 1]
box_height = boxes[:, :, 2] - boxes[:, :, 0]
areas_sqrt = tf.sqrt(
tf.cast(box_height, tf.float32) * tf.cast(box_width, tf.float32))
levels = tf.cast(
tf.math.floordiv(
tf.math.log(tf.math.divide_no_nan(areas_sqrt, 224.0)),
tf.math.log(2.0)) + 4.0,
dtype=tf.int32)
# Maps levels between [min_level, max_level].
levels = tf.minimum(max_level, tf.maximum(levels, min_level))
# Projects box location and sizes to corresponding feature levels.
scale_to_level = tf.cast(
tf.pow(tf.constant(2.0), tf.cast(levels, tf.float32)),
dtype=boxes.dtype)
boxes /= tf.expand_dims(scale_to_level, axis=2)
box_width /= scale_to_level
box_height /= scale_to_level
boxes = tf.concat([boxes[:, :, 0:2],
tf.expand_dims(box_height, -1),
tf.expand_dims(box_width, -1)], axis=-1)
# Maps levels to [0, max_level-min_level].
levels -= min_level
level_strides = tf.pow([[2.0]], tf.cast(levels, tf.float32))
boundary = tf.cast(
tf.concat([
tf.expand_dims(
[[tf.cast(max_feature_height, tf.float32)]] / level_strides - 1,
axis=-1),
tf.expand_dims(
[[tf.cast(max_feature_width, tf.float32)]] / level_strides - 1,
axis=-1),
],
axis=-1), boxes.dtype)
# Compute grid positions.
kernel_y, kernel_x, box_gridy0y1, box_gridx0x1 = _compute_grid_positions(
boxes, boundary, output_size, sample_offset)
x_indices = tf.cast(
tf.reshape(box_gridx0x1, [batch_size, num_boxes, output_size * 2]),
dtype=tf.int32)
y_indices = tf.cast(
tf.reshape(box_gridy0y1, [batch_size, num_boxes, output_size * 2]),
dtype=tf.int32)
batch_size_offset = tf.tile(
tf.reshape(
tf.range(batch_size) * batch_dim_size, [batch_size, 1, 1, 1]),
[1, num_boxes, output_size * 2, output_size * 2])
# Get level offset for each box. Each box belongs to one level.
levels_offset = tf.tile(
tf.reshape(
tf.gather(level_dim_offsets, levels),
[batch_size, num_boxes, 1, 1]),
[1, 1, output_size * 2, output_size * 2])
y_indices_offset = tf.tile(
tf.reshape(
y_indices * tf.expand_dims(tf.gather(height_dim_sizes, levels), -1),
[batch_size, num_boxes, output_size * 2, 1]),
[1, 1, 1, output_size * 2])
x_indices_offset = tf.tile(
tf.reshape(x_indices, [batch_size, num_boxes, 1, output_size * 2]),
[1, 1, output_size * 2, 1])
indices = tf.reshape(
batch_size_offset + levels_offset + y_indices_offset + x_indices_offset,
[-1])
# TODO(wangtao): replace tf.gather with tf.gather_nd and try to get similar
# performance.
features_per_box = tf.reshape(
tf.gather(features_r2, indices),
[batch_size, num_boxes, output_size * 2, output_size * 2, num_filters])
# Bilinear interpolation.
features_per_box = _feature_bilinear_interpolation(
features_per_box, kernel_y, kernel_x)
return features_per_box
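# A small sketch of the multilevel crop above, assuming a two-level pyramid
# whose level-2 map is 8x8 and level-3 map is 4x4 (all values illustrative).
# A 16x16-pixel box has sqrt(area) = 16, far below 224, so the level
# assignment formula maps it below min_level and it is clipped to level 2;
# the crop therefore comes from the 8x8 map.
_pyramid_features = {
    '2': tf.ones([1, 8, 8, 3]),
    '3': tf.ones([1, 4, 4, 3]),
}
_pyramid_boxes = tf.constant([[[0., 0., 16., 16.]]])   # [y1, x1, y2, x2]
_pyramid_crop = multilevel_crop_and_resize(
    _pyramid_features, _pyramid_boxes, output_size=7)
# _pyramid_crop has shape [1, 1, 7, 7, 3].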
def _selective_crop_and_resize(features,
boxes,
box_levels,
boundaries,
output_size=7,
sample_offset=0.5,
use_einsum_gather=False):
"""Crop and resize boxes on a set of feature maps.
Given multiple features maps indexed by different levels, and a set of boxes
where each box is mapped to a certain level, it selectively crops and resizes
boxes from the corresponding feature maps to generate the box features.
We follow the ROIAlign technique (see https://arxiv.org/pdf/1703.06870.pdf,
figure 3 for reference). Specifically, for each feature map, we select an
(output_size, output_size) set of pixels corresponding to the box location,
and then use bilinear interpolation to select the feature value for each
pixel.
For performance, we perform the gather and interpolation on all layers as a
single operation. In this op the multi-level features are first stacked and
gathered into [2*output_size, 2*output_size] feature points. Then bilinear
interpolation is performed on the gathered feature points to generate
[output_size, output_size] RoIAlign feature map.
Here is the step-by-step algorithm:
1. The multi-level features are gathered into a
[batch_size, num_boxes, output_size*2, output_size*2, num_filters]
Tensor. The Tensor contains four neighboring feature points for each
vertex in the output grid.
2. Compute the interpolation kernel of shape
[batch_size, num_boxes, output_size*2, output_size*2]. The last 2 axes
can be seen as stacking 2x2 interpolation kernels for all vertices in the
output grid.
3. Element-wise multiply the gathered features and interpolation kernel.
Then apply 2x2 average pooling to reduce spatial dimension to
output_size.
Args:
features: a 5-D tensor of shape [batch_size, num_levels, max_height,
max_width, num_filters] where cropping and resizing are based.
boxes: a 3-D tensor of shape [batch_size, num_boxes, 4] encoding the
information of each box w.r.t. the corresponding feature map.
boxes[:, :, 0:2] are the grid position in (y, x) (float) of the top-left
corner of each box. boxes[:, :, 2:4] are the box sizes in (h, w) (float)
in terms of the number of pixels of the corresponding feature map size.
box_levels: a 3-D tensor of shape [batch_size, num_boxes, 1] representing
the 0-based corresponding feature level index of each box.
boundaries: a 3-D tensor of shape [batch_size, num_boxes, 2] representing
the boundary (in (y, x)) of the corresponding feature map for each box.
Any resampled grid points that go beyond the boundary will be clipped.
output_size: a scalar indicating the output crop size.
sample_offset: a float in [0, 1] indicating the subpixel sample offset from
the grid point.
use_einsum_gather: whether to use einsum instead of gather. Replacing
gather with einsum can improve performance when the feature size is not
large, and einsum is also friendlier to model partitioning. Gather performs
better when the feature size is very large and there are multiple box
levels.
Returns:
features_per_box: a 5-D tensor of shape
[batch_size, num_boxes, output_size, output_size, num_filters]
representing the cropped features.
"""
(batch_size, num_levels, max_feature_height, max_feature_width,
num_filters) = features.get_shape().as_list()
if batch_size is None:
batch_size = tf.shape(features)[0]
_, num_boxes, _ = boxes.get_shape().as_list()
kernel_y, kernel_x, box_gridy0y1, box_gridx0x1 = _compute_grid_positions(
boxes, boundaries, output_size, sample_offset)
x_indices = tf.cast(
tf.reshape(box_gridx0x1, [batch_size, num_boxes, output_size * 2]),
dtype=tf.int32)
y_indices = tf.cast(
tf.reshape(box_gridy0y1, [batch_size, num_boxes, output_size * 2]),
dtype=tf.int32)
if use_einsum_gather:
# Bilinear interpolation is done during the last two gathers:
# f(y, x) = [hy, ly] * [[f00, f01], * [hx, lx]^T
# [f10, f11]]
# [[f00, f01],
# [f10, f11]] = tf.einsum(tf.einsum(features, y_one_hot), x_one_hot)
# where [hy, ly] and [hx, lx] are the bilinear interpolation kernel.
y_indices = tf.cast(
tf.reshape(box_gridy0y1, [batch_size, num_boxes, output_size, 2]),
dtype=tf.int32)
x_indices = tf.cast(
tf.reshape(box_gridx0x1, [batch_size, num_boxes, output_size, 2]),
dtype=tf.int32)
# shape is [batch_size, num_boxes, output_size, 2, height]
grid_y_one_hot = tf.one_hot(
tf.cast(y_indices, tf.int32), max_feature_height, dtype=kernel_y.dtype)
# shape is [batch_size, num_boxes, output_size, 2, width]
grid_x_one_hot = tf.one_hot(
tf.cast(x_indices, tf.int32), max_feature_width, dtype=kernel_x.dtype)
# shape is [batch_size, num_boxes, output_size, height]
grid_y_weight = tf.reduce_sum(
tf.multiply(grid_y_one_hot, kernel_y), axis=-2)
# shape is [batch_size, num_boxes, output_size, width]
grid_x_weight = tf.reduce_sum(
tf.multiply(grid_x_one_hot, kernel_x), axis=-2)
# Gather for y_axis.
# shape is [batch_size, num_boxes, output_size, width, features]
features_per_box = tf.einsum('bmhwf,bmoh->bmowf', features,
tf.cast(grid_y_weight, features.dtype))
# Gather for x_axis.
# shape is [batch_size, num_boxes, output_size, output_size, features]
features_per_box = tf.einsum('bmhwf,bmow->bmhof', features_per_box,
tf.cast(grid_x_weight, features.dtype))
else:
height_dim_offset = max_feature_width
level_dim_offset = max_feature_height * height_dim_offset
batch_dim_offset = num_levels * level_dim_offset
batch_size_offset = tf.tile(
tf.reshape(
tf.range(batch_size) * batch_dim_offset, [batch_size, 1, 1, 1]),
[1, num_boxes, output_size * 2, output_size * 2])
box_levels_offset = tf.tile(
tf.reshape(box_levels * level_dim_offset,
[batch_size, num_boxes, 1, 1]),
[1, 1, output_size * 2, output_size * 2])
y_indices_offset = tf.tile(
tf.reshape(y_indices * height_dim_offset,
[batch_size, num_boxes, output_size * 2, 1]),
[1, 1, 1, output_size * 2])
x_indices_offset = tf.tile(
tf.reshape(x_indices, [batch_size, num_boxes, 1, output_size * 2]),
[1, 1, output_size * 2, 1])
indices = tf.reshape(
batch_size_offset + box_levels_offset + y_indices_offset +
x_indices_offset, [-1])
features = tf.reshape(features, [-1, num_filters])
# TODO(wangtao): replace tf.gather with tf.gather_nd and try to get similar
# performance.
features_per_box = tf.reshape(
tf.gather(features, indices),
[batch_size, num_boxes, output_size * 2, output_size * 2, num_filters])
features_per_box = _feature_bilinear_interpolation(
features_per_box, kernel_y, kernel_x)
return features_per_box
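# A minimal sketch of the einsum gather path above, assuming each box comes
# with its own feature map along the level axis, as in the mask cropping use
# below: one 8x8 map with 3 filters and one box covering its top-left 4x4
# pixels (values are illustrative).
_sel_features = tf.ones([1, 1, 8, 8, 3])               # [batch, levels, h, w, filters]
_sel_boxes = tf.constant([[[0., 0., 4., 4.]]])         # (y, x, h, w)
_sel_levels = tf.zeros([1, 1], dtype=tf.int32)
_sel_boundaries = tf.constant([[[7., 7.]]])
_sel_crop = _selective_crop_and_resize(
    _sel_features, _sel_boxes, _sel_levels, _sel_boundaries,
    output_size=2, use_einsum_gather=True)
# _sel_crop has shape [1, 1, 2, 2, 3].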
def crop_mask_in_target_box(masks,
boxes,
target_boxes,
output_size,
sample_offset=0,
use_einsum=True):
"""Crop masks in target boxes.
Args:
masks: A tensor with a shape of [batch_size, num_masks, height, width].
boxes: a float tensor representing box coordinates that tightly enclose
masks with a shape of [batch_size, num_masks, 4] in un-normalized
coordinates. A box is represented by [ymin, xmin, ymax, xmax].
target_boxes: a float tensor representing target box coordinates for masks
with a shape of [batch_size, num_masks, 4] in un-normalized coordinates. A
box is represented by [ymin, xmin, ymax, xmax].
output_size: A scalar to indicate the output crop size. Currently only
square outputs are supported.
sample_offset: a float in [0, 1] indicating the subpixel sample offset from
the grid point.
use_einsum: Use einsum to replace gather in selective_crop_and_resize.
Returns:
A 4-D tensor representing feature crop of shape
[batch_size, num_boxes, output_size, output_size].
"""
with tf.name_scope('crop_mask_in_target_box'):
# Cast to float32, as the y_transform and other transform variables may
# overflow in float16
masks = tf.cast(masks, tf.float32)
boxes = tf.cast(boxes, tf.float32)
target_boxes = tf.cast(target_boxes, tf.float32)
batch_size, num_masks, height, width = masks.get_shape().as_list()
if batch_size is None:
batch_size = tf.shape(masks)[0]
masks = tf.reshape(masks, [batch_size * num_masks, height, width, 1])
# Pad zeros on the boundary of masks.
masks = tf.image.pad_to_bounding_box(masks, 2, 2, height + 4, width + 4)
masks = tf.reshape(masks, [batch_size, num_masks, height+4, width+4, 1])
# Projects target box locations and sizes to corresponding cropped
# mask coordinates.
gt_y_min, gt_x_min, gt_y_max, gt_x_max = tf.split(
value=boxes, num_or_size_splits=4, axis=2)
bb_y_min, bb_x_min, bb_y_max, bb_x_max = tf.split(
value=target_boxes, num_or_size_splits=4, axis=2)
y_transform = (bb_y_min - gt_y_min) * height / (
gt_y_max - gt_y_min + _EPSILON) + 2
x_transform = (bb_x_min - gt_x_min) * height / (
gt_x_max - gt_x_min + _EPSILON) + 2
h_transform = (bb_y_max - bb_y_min) * width / (
gt_y_max - gt_y_min + _EPSILON)
w_transform = (bb_x_max - bb_x_min) * width / (
gt_x_max - gt_x_min + _EPSILON)
boundaries = tf.concat(
[tf.ones_like(y_transform) * ((height + 4) - 1),
tf.ones_like(x_transform) * ((width + 4) - 1)],
axis=-1)
boundaries = tf.cast(boundaries, dtype=y_transform.dtype)
# Reshape tensors to have the right shape for selective_crop_and_resize.
transformed_boxes = tf.concat(
[y_transform, x_transform, h_transform, w_transform], -1)
levels = tf.tile(tf.reshape(tf.range(num_masks), [1, num_masks]),
[batch_size, 1])
cropped_masks = _selective_crop_and_resize(
masks,
transformed_boxes,
levels,
boundaries,
output_size,
sample_offset=sample_offset,
use_einsum_gather=use_einsum)
cropped_masks = tf.squeeze(cropped_masks, axis=-1)
return cropped_masks
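# A small sketch of the mask re-cropping above, assuming one 8x8 instance
# mask whose tight box covers the whole mask and whose target box is the
# central 4x4 region (values are illustrative).
_mask_values = tf.ones([1, 1, 8, 8])
_mask_boxes = tf.constant([[[0., 0., 8., 8.]]])        # [ymin, xmin, ymax, xmax]
_mask_targets = tf.constant([[[2., 2., 6., 6.]]])
_mask_crop = crop_mask_in_target_box(
    _mask_values, _mask_boxes, _mask_targets, output_size=4)
# _mask_crop has shape [1, 1, 4, 4] and is all ones for this input.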
def nearest_upsampling(data, scale, use_keras_layer=False):
"""Nearest neighbor upsampling implementation.
Args:
data: A tensor with a shape of [batch, height_in, width_in, channels].
scale: An integer multiple to scale resolution of input data.
use_keras_layer: If True, use keras Upsampling2D layer.
Returns:
data_up: A tensor with a shape of
[batch, height_in*scale, width_in*scale, channels]. Same dtype as input
data.
"""
if use_keras_layer:
return tf.keras.layers.UpSampling2D(size=(scale, scale),
interpolation='nearest')(data)
with tf.name_scope('nearest_upsampling'):
bs, _, _, c = data.get_shape().as_list()
shape = tf.shape(input=data)
h = shape[1]
w = shape[2]
bs = -1 if bs is None else bs
# Uses reshape to quickly upsample the input. The nearest pixel is selected
# via tiling.
data = tf.tile(
tf.reshape(data, [bs, h, 1, w, 1, c]), [1, 1, scale, 1, scale, 1])
return tf.reshape(data, [bs, h * scale, w * scale, c])
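# A tiny sketch of the reshape/tile trick above: a [1, 2, 2, 1] input
# upsampled by a factor of 2 repeats each pixel in a 2x2 block (values are
# illustrative).
_nn_data = tf.constant([[[[1.], [2.]],
                         [[3.], [4.]]]])               # [1, 2, 2, 1]
_nn_up = nearest_upsampling(_nn_data, scale=2)
# _nn_up[0, :, :, 0] ->
# [[1., 1., 2., 2.],
#  [1., 1., 2., 2.],
#  [3., 3., 4., 4.],
#  [3., 3., 4., 4.]]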
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Definition of target gather, which gathers targets from indices."""
import tensorflow as tf
class TargetGather:
"""Targer gather for dense object detector."""
def __call__(self, labels, match_indices, mask=None, mask_val=0.0):
"""Labels anchors with ground truth inputs.
B: batch_size
N: number of groundtruth boxes.
Args:
labels: An integer tensor with shape [N, dims] or [B, N, ...] representing
groundtruth labels.
match_indices: An integer tensor with shape [M] or [B, M] representing
matched label indices.
mask: A boolean tensor with shape [M, dims] or [B, M, ...] representing
matched labels.
mask_val: A value to fill in for masked entries.
Returns:
target: An integer Tensor with shape [M] or [B, M].
Raises:
ValueError: If `labels` has rank higher than 3.
"""
if len(labels.shape) <= 2:
return self._gather_unbatched(labels, match_indices, mask, mask_val)
elif len(labels.shape) == 3:
return self._gather_batched(labels, match_indices, mask, mask_val)
else:
raise ValueError("`TargetGather` does not support `labels` with rank "
"larger than 3, got {}".format(len(labels.shape)))
def _gather_unbatched(self, labels, match_indices, mask, mask_val):
"""Gather based on unbatched labels and boxes."""
num_gt_boxes = tf.shape(labels)[0]
def _assign_when_rows_empty():
if len(labels.shape) > 1:
mask_shape = [match_indices.shape[0], labels.shape[-1]]
else:
mask_shape = [match_indices.shape[0]]
return tf.cast(mask_val, labels.dtype) * tf.ones(
mask_shape, dtype=labels.dtype)
def _assign_when_rows_not_empty():
targets = tf.gather(labels, match_indices)
if mask is None:
return targets
else:
masked_targets = tf.cast(mask_val, labels.dtype) * tf.ones_like(
mask, dtype=labels.dtype)
return tf.where(mask, masked_targets, targets)
return tf.cond(tf.greater(num_gt_boxes, 0),
_assign_when_rows_not_empty,
_assign_when_rows_empty)
def _gather_batched(self, labels, match_indices, mask, mask_val):
"""Gather based on batched labels."""
batch_size = labels.shape[0]
if batch_size == 1:
if mask is not None:
result = self._gather_unbatched(
tf.squeeze(labels, axis=0), tf.squeeze(match_indices, axis=0),
tf.squeeze(mask, axis=0), mask_val)
else:
result = self._gather_unbatched(
tf.squeeze(labels, axis=0), tf.squeeze(match_indices, axis=0),
None, mask_val)
return tf.expand_dims(result, axis=0)
else:
indices_shape = tf.shape(match_indices)
indices_dtype = match_indices.dtype
batch_indices = (tf.expand_dims(
tf.range(indices_shape[0], dtype=indices_dtype), axis=-1) *
tf.ones([1, indices_shape[-1]], dtype=indices_dtype))
gather_nd_indices = tf.stack(
[batch_indices, match_indices], axis=-1)
targets = tf.gather_nd(labels, gather_nd_indices)
if mask is None:
return targets
else:
masked_targets = tf.cast(mask_val, labels.dtype) * tf.ones_like(
mask, dtype=labels.dtype)
return tf.where(mask, masked_targets, targets)
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for target_gather.py."""
import tensorflow as tf
from official.vision.ops import target_gather
class TargetGatherTest(tf.test.TestCase):
def test_target_gather_batched(self):
gt_boxes = tf.constant(
[[
[0, 0, 5, 5],
[0, 5, 5, 10],
[5, 0, 10, 5],
[5, 5, 10, 10],
]],
dtype=tf.float32)
gt_classes = tf.constant([[[2], [10], [3], [-1]]], dtype=tf.int32)
labeler = target_gather.TargetGather()
match_indices = tf.constant([[0, 2]], dtype=tf.int32)
match_indicators = tf.constant([[-2, 1]])
mask = tf.less_equal(match_indicators, 0)
cls_mask = tf.expand_dims(mask, -1)
matched_gt_classes = labeler(gt_classes, match_indices, cls_mask)
box_mask = tf.tile(cls_mask, [1, 1, 4])
matched_gt_boxes = labeler(gt_boxes, match_indices, box_mask)
self.assertAllEqual(
matched_gt_classes.numpy(), [[[0], [3]]])
self.assertAllClose(
matched_gt_boxes.numpy(), [[[0, 0, 0, 0], [5, 0, 10, 5]]])
def test_target_gather_unbatched(self):
gt_boxes = tf.constant(
[
[0, 0, 5, 5],
[0, 5, 5, 10],
[5, 0, 10, 5],
[5, 5, 10, 10],
],
dtype=tf.float32)
gt_classes = tf.constant([[2], [10], [3], [-1]], dtype=tf.int32)
labeler = target_gather.TargetGather()
match_indices = tf.constant([0, 2], dtype=tf.int32)
match_indicators = tf.constant([-2, 1])
mask = tf.less_equal(match_indicators, 0)
cls_mask = tf.expand_dims(mask, -1)
matched_gt_classes = labeler(gt_classes, match_indices, cls_mask)
box_mask = tf.tile(cls_mask, [1, 4])
matched_gt_boxes = labeler(gt_boxes, match_indices, box_mask)
self.assertAllEqual(
matched_gt_classes.numpy(), [[0], [3]])
self.assertAllClose(
matched_gt_boxes.numpy(), [[0, 0, 0, 0], [5, 0, 10, 5]])
if __name__ == '__main__':
tf.test.main()
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""All necessary imports for registration."""
# pylint: disable=unused-import
from official import vision
from official.nlp import tasks
from official.nlp.configs import experiment_configs
from official.utils.testing import mock_task
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Detection input and model functions for serving/inference."""
from typing import Mapping, Text
import tensorflow as tf
from official.vision import configs
from official.vision.modeling import factory
from official.vision.ops import anchor
from official.vision.ops import box_ops
from official.vision.ops import preprocess_ops
from official.vision.serving import export_base
MEAN_RGB = (0.485 * 255, 0.456 * 255, 0.406 * 255)
STDDEV_RGB = (0.229 * 255, 0.224 * 255, 0.225 * 255)
class DetectionModule(export_base.ExportModule):
"""Detection Module."""
def _build_model(self):
if self._batch_size is None:
raise ValueError('batch_size cannot be None for detection models.')
input_specs = tf.keras.layers.InputSpec(shape=[self._batch_size] +
self._input_image_size + [3])
if isinstance(self.params.task.model, configs.maskrcnn.MaskRCNN):
model = factory.build_maskrcnn(
input_specs=input_specs, model_config=self.params.task.model)
elif isinstance(self.params.task.model, configs.retinanet.RetinaNet):
model = factory.build_retinanet(
input_specs=input_specs, model_config=self.params.task.model)
else:
raise ValueError('Detection module not implemented for {} model.'.format(
type(self.params.task.model)))
return model
def _build_anchor_boxes(self):
"""Builds and returns anchor boxes."""
model_params = self.params.task.model
input_anchor = anchor.build_anchor_generator(
min_level=model_params.min_level,
max_level=model_params.max_level,
num_scales=model_params.anchor.num_scales,
aspect_ratios=model_params.anchor.aspect_ratios,
anchor_size=model_params.anchor.anchor_size)
return input_anchor(
image_size=(self._input_image_size[0], self._input_image_size[1]))
def _build_inputs(self, image):
"""Builds detection model inputs for serving."""
model_params = self.params.task.model
# Normalizes image with mean and std pixel values.
image = preprocess_ops.normalize_image(image,
offset=MEAN_RGB,
scale=STDDEV_RGB)
image, image_info = preprocess_ops.resize_and_crop_image(
image,
self._input_image_size,
padded_size=preprocess_ops.compute_padded_size(
self._input_image_size, 2**model_params.max_level),
aug_scale_min=1.0,
aug_scale_max=1.0)
anchor_boxes = self._build_anchor_boxes()
return image, anchor_boxes, image_info
def preprocess(self, images: tf.Tensor) -> (
tf.Tensor, Mapping[Text, tf.Tensor], tf.Tensor):
"""Preprocess inputs to be suitable for the model.
Args:
images: The images tensor.
Returns:
images: The images tensor cast to float.
anchor_boxes: Dict mapping anchor levels to anchor boxes.
image_info: Tensor containing the details of the image resizing.
"""
model_params = self.params.task.model
with tf.device('cpu:0'):
images = tf.cast(images, dtype=tf.float32)
# Tensor Specs for map_fn outputs (images, anchor_boxes, and image_info).
images_spec = tf.TensorSpec(shape=self._input_image_size + [3],
dtype=tf.float32)
num_anchors = model_params.anchor.num_scales * len(
model_params.anchor.aspect_ratios) * 4
anchor_shapes = []
for level in range(model_params.min_level, model_params.max_level + 1):
anchor_level_spec = tf.TensorSpec(
shape=[
self._input_image_size[0] // 2**level,
self._input_image_size[1] // 2**level, num_anchors
],
dtype=tf.float32)
anchor_shapes.append((str(level), anchor_level_spec))
image_info_spec = tf.TensorSpec(shape=[4, 2], dtype=tf.float32)
images, anchor_boxes, image_info = tf.nest.map_structure(
tf.identity,
tf.map_fn(
self._build_inputs,
elems=images,
fn_output_signature=(images_spec, dict(anchor_shapes),
image_info_spec),
parallel_iterations=32))
return images, anchor_boxes, image_info
def serve(self, images: tf.Tensor):
"""Cast image to float and run inference.
Args:
images: uint8 Tensor of shape [batch_size, None, None, 3]
Returns:
Tensor holding detection output logits.
"""
# Skip image preprocessing when input_type is tflite so it is compatible
# with TFLite quantization.
if self._input_type != 'tflite':
images, anchor_boxes, image_info = self.preprocess(images)
else:
with tf.device('cpu:0'):
anchor_boxes = self._build_anchor_boxes()
# image_info is a 3D tensor of shape [batch_size, 4, 2]. It is in the
# format of [[original_height, original_width],
# [desired_height, desired_width], [y_scale, x_scale],
# [y_offset, x_offset]]. When input_type is tflite, input image is
# supposed to be preprocessed already.
image_info = tf.convert_to_tensor([[
self._input_image_size, self._input_image_size, [1.0, 1.0], [0, 0]
]],
dtype=tf.float32)
input_image_shape = image_info[:, 1, :]
# To work around a keras.Model limitation when saving a model whose layers
# take multiple inputs, we use `model.call` here to trigger the forward
# pass. Note that this bypasses some of the Keras magic that happens in
# `__call__`.
detections = self.model.call(
images=images,
image_shape=input_image_shape,
anchor_boxes=anchor_boxes,
training=False)
if self.params.task.model.detection_generator.apply_nms:
# For RetinaNet model, apply export_config.
# TODO(huizhongc): Add export_config to fasterrcnn and maskrcnn as needed.
if isinstance(self.params.task.model, configs.retinanet.RetinaNet):
export_config = self.params.task.export_config
# Normalize detection box coordinates to [0, 1].
if export_config.output_normalized_coordinates:
detection_boxes = (
detections['detection_boxes'] /
tf.tile(image_info[:, 2:3, :], [1, 1, 2]))
detections['detection_boxes'] = box_ops.normalize_boxes(
detection_boxes, image_info[:, 0:1, :])
# Cast num_detections and detection_classes to float. This allows the
# model inference to work on chain (go/chain) as chain requires floating
# point outputs.
if export_config.cast_num_detections_to_float:
detections['num_detections'] = tf.cast(
detections['num_detections'], dtype=tf.float32)
if export_config.cast_detection_classes_to_float:
detections['detection_classes'] = tf.cast(
detections['detection_classes'], dtype=tf.float32)
final_outputs = {
'detection_boxes': detections['detection_boxes'],
'detection_scores': detections['detection_scores'],
'detection_classes': detections['detection_classes'],
'num_detections': detections['num_detections']
}
else:
final_outputs = {
'decoded_boxes': detections['decoded_boxes'],
'decoded_box_scores': detections['decoded_box_scores']
}
if 'detection_masks' in detections.keys():
final_outputs['detection_masks'] = detections['detection_masks']
final_outputs.update({'image_info': image_info})
return final_outputs
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Test for image detection export lib."""
import io
import os
from absl.testing import parameterized
import numpy as np
from PIL import Image
import tensorflow as tf
from official.core import exp_factory
from official.vision import registry_imports # pylint: disable=unused-import
from official.vision.serving import detection
class DetectionExportTest(tf.test.TestCase, parameterized.TestCase):
def _get_detection_module(self, experiment_name, input_type):
params = exp_factory.get_exp_config(experiment_name)
params.task.model.backbone.resnet.model_id = 18
params.task.model.detection_generator.nms_version = 'batched'
detection_module = detection.DetectionModule(
params,
batch_size=1,
input_image_size=[640, 640],
input_type=input_type)
return detection_module
def _export_from_module(self, module, input_type, save_directory):
signatures = module.get_inference_signatures(
{input_type: 'serving_default'})
tf.saved_model.save(module, save_directory, signatures=signatures)
def _get_dummy_input(self, input_type, batch_size, image_size):
"""Get dummy input for the given input type."""
h, w = image_size
if input_type == 'image_tensor':
return tf.zeros((batch_size, h, w, 3), dtype=np.uint8)
elif input_type == 'image_bytes':
image = Image.fromarray(np.zeros((h, w, 3), dtype=np.uint8))
byte_io = io.BytesIO()
image.save(byte_io, 'PNG')
return [byte_io.getvalue() for b in range(batch_size)]
elif input_type == 'tf_example':
image_tensor = tf.zeros((h, w, 3), dtype=tf.uint8)
encoded_jpeg = tf.image.encode_jpeg(tf.constant(image_tensor)).numpy()
example = tf.train.Example(
features=tf.train.Features(
feature={
'image/encoded':
tf.train.Feature(
bytes_list=tf.train.BytesList(value=[encoded_jpeg])),
})).SerializeToString()
return [example for b in range(batch_size)]
elif input_type == 'tflite':
return tf.zeros((batch_size, h, w, 3), dtype=np.float32)
@parameterized.parameters(
('image_tensor', 'fasterrcnn_resnetfpn_coco', [384, 384]),
('image_bytes', 'fasterrcnn_resnetfpn_coco', [640, 640]),
('tf_example', 'fasterrcnn_resnetfpn_coco', [640, 640]),
('tflite', 'fasterrcnn_resnetfpn_coco', [640, 640]),
('image_tensor', 'maskrcnn_resnetfpn_coco', [640, 640]),
('image_bytes', 'maskrcnn_resnetfpn_coco', [640, 384]),
('tf_example', 'maskrcnn_resnetfpn_coco', [640, 640]),
('tflite', 'maskrcnn_resnetfpn_coco', [640, 640]),
('image_tensor', 'retinanet_resnetfpn_coco', [640, 640]),
('image_bytes', 'retinanet_resnetfpn_coco', [640, 640]),
('tf_example', 'retinanet_resnetfpn_coco', [384, 640]),
('tflite', 'retinanet_resnetfpn_coco', [640, 640]),
('image_tensor', 'retinanet_resnetfpn_coco', [384, 384]),
('image_bytes', 'retinanet_spinenet_coco', [640, 640]),
('tf_example', 'retinanet_spinenet_coco', [640, 384]),
('tflite', 'retinanet_spinenet_coco', [640, 640]),
)
def test_export(self, input_type, experiment_name, image_size):
tmp_dir = self.get_temp_dir()
module = self._get_detection_module(experiment_name, input_type)
self._export_from_module(module, input_type, tmp_dir)
self.assertTrue(os.path.exists(os.path.join(tmp_dir, 'saved_model.pb')))
self.assertTrue(
os.path.exists(os.path.join(tmp_dir, 'variables', 'variables.index')))
self.assertTrue(
os.path.exists(
os.path.join(tmp_dir, 'variables',
'variables.data-00000-of-00001')))
imported = tf.saved_model.load(tmp_dir)
detection_fn = imported.signatures['serving_default']
images = self._get_dummy_input(
input_type, batch_size=1, image_size=image_size)
if input_type == 'tflite':
processed_images = tf.zeros(image_size + [3], dtype=tf.float32)
anchor_boxes = module._build_anchor_boxes()
image_info = tf.convert_to_tensor(
[image_size, image_size, [1.0, 1.0], [0, 0]], dtype=tf.float32)
else:
processed_images, anchor_boxes, image_info = module._build_inputs(
tf.zeros((224, 224, 3), dtype=tf.uint8))
image_shape = image_info[1, :]
image_shape = tf.expand_dims(image_shape, 0)
processed_images = tf.expand_dims(processed_images, 0)
for l, l_boxes in anchor_boxes.items():
anchor_boxes[l] = tf.expand_dims(l_boxes, 0)
expected_outputs = module.model(
images=processed_images,
image_shape=image_shape,
anchor_boxes=anchor_boxes,
training=False)
outputs = detection_fn(tf.constant(images))
self.assertAllClose(outputs['num_detections'].numpy(),
expected_outputs['num_detections'].numpy())
def test_build_model_fail_with_none_batch_size(self):
params = exp_factory.get_exp_config('retinanet_resnetfpn_coco')
with self.assertRaisesRegex(
ValueError, 'batch_size cannot be None for detection models.'):
detection.DetectionModule(
params, batch_size=None, input_image_size=[640, 640])
if __name__ == '__main__':
tf.test.main()
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Base class for model export."""
import abc
from typing import Dict, List, Mapping, Optional, Text
import tensorflow as tf
from official.core import config_definitions as cfg
from official.core import export_base
class ExportModule(export_base.ExportModule, metaclass=abc.ABCMeta):
"""Base Export Module."""
def __init__(self,
params: cfg.ExperimentConfig,
*,
batch_size: int,
input_image_size: List[int],
input_type: str = 'image_tensor',
num_channels: int = 3,
model: Optional[tf.keras.Model] = None):
"""Initializes a module for export.
Args:
params: Experiment params.
batch_size: The batch size of the model input. Can be `int` or None.
input_image_size: List or Tuple of the input image size. For a 2D image,
it is [height, width].
input_type: The input signature type.
num_channels: The number of the image channels.
model: A tf.keras.Model instance to be exported.
"""
self.params = params
self._batch_size = batch_size
self._input_image_size = input_image_size
self._num_channels = num_channels
self._input_type = input_type
if model is None:
model = self._build_model() # pylint: disable=assignment-from-none
super().__init__(params=params, model=model)
def _decode_image(self, encoded_image_bytes: str) -> tf.Tensor:
"""Decodes an image bytes to an image tensor.
Uses `tf.image.decode_image` to decode the image if the input is expected to
be a 2D image; otherwise uses `tf.io.decode_raw` to convert the raw bytes to a
tensor and reshapes it to the desired shape.
Args:
encoded_image_bytes: An encoded image string to be decoded.
Returns:
A decoded image tensor.
"""
if len(self._input_image_size) == 2:
# Decode an image if 2D input is expected.
image_tensor = tf.image.decode_image(
encoded_image_bytes, channels=self._num_channels)
image_tensor.set_shape((None, None, self._num_channels))
else:
# Convert raw bytes into a tensor and reshape it, if not 2D input.
image_tensor = tf.io.decode_raw(encoded_image_bytes, out_type=tf.uint8)
image_tensor = tf.reshape(image_tensor,
self._input_image_size + [self._num_channels])
return image_tensor
def _decode_tf_example(
self, tf_example_string_tensor: tf.train.Example) -> tf.Tensor:
"""Decodes a TF Example to an image tensor.
Args:
tf_example_string_tensor: A tf.train.Example of encoded image and other
information.
Returns:
A decoded image tensor.
"""
keys_to_features = {'image/encoded': tf.io.FixedLenFeature((), tf.string)}
parsed_tensors = tf.io.parse_single_example(
serialized=tf_example_string_tensor, features=keys_to_features)
image_tensor = self._decode_image(parsed_tensors['image/encoded'])
return image_tensor
def _build_model(self, **kwargs):
"""Returns a model built from the params."""
return None
@tf.function
def inference_from_image_tensors(
self, inputs: tf.Tensor) -> Mapping[str, tf.Tensor]:
return self.serve(inputs)
@tf.function
def inference_for_tflite(self, inputs: tf.Tensor) -> Mapping[str, tf.Tensor]:
return self.serve(inputs)
@tf.function
def inference_from_image_bytes(self, inputs: tf.Tensor):
with tf.device('cpu:0'):
images = tf.nest.map_structure(
tf.identity,
tf.map_fn(
self._decode_image,
elems=inputs,
fn_output_signature=tf.TensorSpec(
shape=[None] * len(self._input_image_size) +
[self._num_channels],
dtype=tf.uint8),
parallel_iterations=32))
images = tf.stack(images)
return self.serve(images)
@tf.function
def inference_from_tf_example(self,
inputs: tf.Tensor) -> Mapping[str, tf.Tensor]:
with tf.device('cpu:0'):
images = tf.nest.map_structure(
tf.identity,
tf.map_fn(
self._decode_tf_example,
elems=inputs,
# Height/width of the shape of input images is unspecified (None)
# at the time of decoding the example, but the shape will
# be adjusted to conform to the input layer of the model,
# by _run_inference_on_image_tensors() below.
fn_output_signature=tf.TensorSpec(
shape=[None] * len(self._input_image_size) +
[self._num_channels],
dtype=tf.uint8),
parallel_iterations=32))
images = tf.stack(images)
return self.serve(images)
def get_inference_signatures(self, function_keys: Dict[Text, Text]):
"""Gets defined function signatures.
Args:
function_keys: A dictionary whose keys name the functions to create
signatures for and whose values are the signature keys to use in the result.
Returns:
A dictionary with key as signature key and value as concrete functions
that can be used for tf.saved_model.save.
"""
signatures = {}
for key, def_name in function_keys.items():
if key == 'image_tensor':
input_signature = tf.TensorSpec(
shape=[self._batch_size] + [None] * len(self._input_image_size) +
[self._num_channels],
dtype=tf.uint8)
signatures[
def_name] = self.inference_from_image_tensors.get_concrete_function(
input_signature)
elif key == 'image_bytes':
input_signature = tf.TensorSpec(
shape=[self._batch_size], dtype=tf.string)
signatures[
def_name] = self.inference_from_image_bytes.get_concrete_function(
input_signature)
elif key == 'serve_examples' or key == 'tf_example':
input_signature = tf.TensorSpec(
shape=[self._batch_size], dtype=tf.string)
signatures[
def_name] = self.inference_from_tf_example.get_concrete_function(
input_signature)
elif key == 'tflite':
input_signature = tf.TensorSpec(
shape=[self._batch_size] + self._input_image_size +
[self._num_channels],
dtype=tf.float32)
signatures[def_name] = self.inference_for_tflite.get_concrete_function(
input_signature)
else:
raise ValueError('Unrecognized `input_type`')
return signatures
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base class for model export."""
from typing import Dict, Optional, Text, Callable, Any, Union
import tensorflow as tf
from official.core import export_base
class ExportModule(export_base.ExportModule):
"""Base Export Module."""
def __init__(self,
params,
model: tf.keras.Model,
input_signature: Union[tf.TensorSpec, Dict[str, tf.TensorSpec]],
preprocessor: Optional[Callable[..., Any]] = None,
inference_step: Optional[Callable[..., Any]] = None,
postprocessor: Optional[Callable[..., Any]] = None):
"""Initializes a module for export.
Args:
params: A dataclass for parameters to the module.
model: A tf.keras.Model instance to be exported.
input_signature: tf.TensorSpec, e.g.
tf.TensorSpec(shape=[None, 224, 224, 3], dtype=tf.uint8)
preprocessor: An optional callable to preprocess the inputs.
inference_step: An optional callable to forward-pass the model.
postprocessor: An optional callable to postprocess the model outputs.
"""
super().__init__(
params,
model=model,
preprocessor=preprocessor,
inference_step=inference_step,
postprocessor=postprocessor)
self.input_signature = input_signature
@tf.function
def serve(self, inputs):
x = self.preprocessor(inputs=inputs) if self.preprocessor else inputs
x = self.inference_step(x)
x = self.postprocessor(x) if self.postprocessor else x
return x
def get_inference_signatures(self, function_keys: Dict[Text, Text]):
"""Gets defined function signatures.
Args:
function_keys: A dictionary whose keys name the functions to create
signatures for and whose values are the signature keys to use in the result.
Returns:
A dictionary with key as signature key and value as concrete functions
that can be used for tf.saved_model.save.
"""
signatures = {}
for _, def_name in function_keys.items():
signatures[def_name] = self.serve.get_concrete_function(
self.input_signature)
return signatures
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for official.core.export_base_v2."""
import os
import tensorflow as tf
from official.core import export_base
from official.vision.serving import export_base_v2
class TestModel(tf.keras.Model):
def __init__(self):
super().__init__()
self._dense = tf.keras.layers.Dense(2)
def call(self, inputs):
return {'outputs': self._dense(inputs)}
class ExportBaseTest(tf.test.TestCase):
def test_preprocessor(self):
tmp_dir = self.get_temp_dir()
model = TestModel()
inputs = tf.ones([2, 4], tf.float32)
preprocess_fn = lambda inputs: 2 * inputs
module = export_base_v2.ExportModule(
params=None,
input_signature=tf.TensorSpec(shape=[2, 4]),
model=model,
preprocessor=preprocess_fn)
expected_output = model(preprocess_fn(inputs))
ckpt_path = tf.train.Checkpoint(model=model).save(
os.path.join(tmp_dir, 'ckpt'))
export_dir = export_base.export(
module, ['serving_default'],
export_savedmodel_dir=tmp_dir,
checkpoint_path=ckpt_path,
timestamped=False)
imported = tf.saved_model.load(export_dir)
output = imported.signatures['serving_default'](inputs)
print('output', output)
self.assertAllClose(
output['outputs'].numpy(), expected_output['outputs'].numpy())
def test_postprocessor(self):
tmp_dir = self.get_temp_dir()
model = TestModel()
inputs = tf.ones([2, 4], tf.float32)
postprocess_fn = lambda logits: {'outputs': 2 * logits['outputs']}
module = export_base_v2.ExportModule(
params=None,
model=model,
input_signature=tf.TensorSpec(shape=[2, 4]),
postprocessor=postprocess_fn)
expected_output = postprocess_fn(model(inputs))
ckpt_path = tf.train.Checkpoint(model=model).save(
os.path.join(tmp_dir, 'ckpt'))
export_dir = export_base.export(
module, ['serving_default'],
export_savedmodel_dir=tmp_dir,
checkpoint_path=ckpt_path,
timestamped=False)
imported = tf.saved_model.load(export_dir)
output = imported.signatures['serving_default'](inputs)
self.assertAllClose(
output['outputs'].numpy(), expected_output['outputs'].numpy())
if __name__ == '__main__':
tf.test.main()
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Factory for vision export modules."""
from typing import List, Optional
import tensorflow as tf
from official.core import config_definitions as cfg
from official.vision import configs
from official.vision.dataloaders import classification_input
from official.vision.modeling import factory
from official.vision.serving import export_base_v2 as export_base
from official.vision.serving import export_utils
def create_classification_export_module(params: cfg.ExperimentConfig,
input_type: str,
batch_size: int,
input_image_size: List[int],
num_channels: int = 3):
"""Creats classification export module."""
input_signature = export_utils.get_image_input_signatures(
input_type, batch_size, input_image_size, num_channels)
input_specs = tf.keras.layers.InputSpec(
shape=[batch_size] + input_image_size + [num_channels])
model = factory.build_classification_model(
input_specs=input_specs,
model_config=params.task.model,
l2_regularizer=None)
def preprocess_fn(inputs):
image_tensor = export_utils.parse_image(inputs, input_type,
input_image_size, num_channels)
# If input_type is `tflite`, do not apply image preprocessing.
if input_type == 'tflite':
return image_tensor
def preprocess_image_fn(inputs):
return classification_input.Parser.inference_fn(
inputs, input_image_size, num_channels)
images = tf.map_fn(
preprocess_image_fn, elems=image_tensor,
fn_output_signature=tf.TensorSpec(
shape=input_image_size + [num_channels],
dtype=tf.float32))
return images
def postprocess_fn(logits):
probs = tf.nn.softmax(logits)
return {'logits': logits, 'probs': probs}
export_module = export_base.ExportModule(params,
model=model,
input_signature=input_signature,
preprocessor=preprocess_fn,
postprocessor=postprocess_fn)
return export_module
def get_export_module(params: cfg.ExperimentConfig,
input_type: str,
batch_size: Optional[int],
input_image_size: List[int],
num_channels: int = 3) -> export_base.ExportModule:
"""Factory for export modules."""
if isinstance(params.task,
configs.image_classification.ImageClassificationTask):
export_module = create_classification_export_module(
params, input_type, batch_size, input_image_size, num_channels)
else:
raise ValueError('Export module not implemented for {} task.'.format(
type(params.task)))
return export_module
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test for vision modules."""
import io
import os
from absl.testing import parameterized
import numpy as np
from PIL import Image
import tensorflow as tf
from official.core import exp_factory
from official.core import export_base
from official.vision import registry_imports # pylint: disable=unused-import
from official.vision.dataloaders import classification_input
from official.vision.serving import export_module_factory
class ImageClassificationExportTest(tf.test.TestCase, parameterized.TestCase):
def _get_classification_module(self, input_type, input_image_size):
params = exp_factory.get_exp_config('resnet_imagenet')
params.task.model.backbone.resnet.model_id = 18
module = export_module_factory.create_classification_export_module(
params, input_type, batch_size=1, input_image_size=input_image_size)
return module
def _get_dummy_input(self, input_type):
"""Get dummy input for the given input type."""
if input_type == 'image_tensor':
return tf.zeros((1, 32, 32, 3), dtype=np.uint8)
elif input_type == 'image_bytes':
image = Image.fromarray(np.zeros((32, 32, 3), dtype=np.uint8))
byte_io = io.BytesIO()
image.save(byte_io, 'PNG')
return [byte_io.getvalue()]
elif input_type == 'tf_example':
image_tensor = tf.zeros((32, 32, 3), dtype=tf.uint8)
encoded_jpeg = tf.image.encode_jpeg(tf.constant(image_tensor)).numpy()
example = tf.train.Example(
features=tf.train.Features(
feature={
'image/encoded':
tf.train.Feature(
bytes_list=tf.train.BytesList(value=[encoded_jpeg])),
})).SerializeToString()
return [example]
@parameterized.parameters(
{'input_type': 'image_tensor'},
{'input_type': 'image_bytes'},
{'input_type': 'tf_example'},
)
def test_export(self, input_type='image_tensor'):
input_image_size = [32, 32]
tmp_dir = self.get_temp_dir()
module = self._get_classification_module(input_type, input_image_size)
# Test that the model restores any attrs that are trackable objects
# (eg: tables, resource variables, keras models/layers, tf.hub modules).
module.model.test_trackable = tf.keras.layers.InputLayer(input_shape=(4,))
ckpt_path = tf.train.Checkpoint(model=module.model).save(
os.path.join(tmp_dir, 'ckpt'))
export_dir = export_base.export(
module, [input_type],
export_savedmodel_dir=tmp_dir,
checkpoint_path=ckpt_path,
timestamped=False)
self.assertTrue(os.path.exists(os.path.join(tmp_dir, 'saved_model.pb')))
self.assertTrue(os.path.exists(
os.path.join(tmp_dir, 'variables', 'variables.index')))
self.assertTrue(os.path.exists(
os.path.join(tmp_dir, 'variables', 'variables.data-00000-of-00001')))
imported = tf.saved_model.load(export_dir)
classification_fn = imported.signatures['serving_default']
images = self._get_dummy_input(input_type)
def preprocess_image_fn(inputs):
return classification_input.Parser.inference_fn(
inputs, input_image_size, num_channels=3)
processed_images = tf.map_fn(
preprocess_image_fn,
elems=tf.zeros([1] + input_image_size + [3], dtype=tf.uint8),
fn_output_signature=tf.TensorSpec(
shape=input_image_size + [3], dtype=tf.float32))
expected_logits = module.model(processed_images, training=False)
expected_prob = tf.nn.softmax(expected_logits)
out = classification_fn(tf.constant(images))
# The imported model should contain any trackable attrs that the original
# model had.
self.assertTrue(hasattr(imported.model, 'test_trackable'))
self.assertAllClose(
out['logits'].numpy(), expected_logits.numpy(), rtol=1e-04, atol=1e-04)
self.assertAllClose(
out['probs'].numpy(), expected_prob.numpy(), rtol=1e-04, atol=1e-04)
if __name__ == '__main__':
tf.test.main()
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
r"""Vision models export binary for serving/inference.
To export a trained checkpoint in saved_model format (shell script):
EXPERIMENT_TYPE = XX
CHECKPOINT_PATH = XX
EXPORT_DIR_PATH = XX
export_saved_model --experiment=${EXPERIMENT_TYPE} \
--export_dir=${EXPORT_DIR_PATH}/ \
--checkpoint_path=${CHECKPOINT_PATH} \
--batch_size=2 \
--input_image_size=224,224
To serve (python):
export_dir_path = XX
input_type = XX
input_images = XX
imported = tf.saved_model.load(export_dir_path)
model_fn = imported.signatures['serving_default']
output = model_fn(input_images)
"""
from absl import app
from absl import flags
from official.core import exp_factory
from official.modeling import hyperparams
from official.vision import registry_imports # pylint: disable=unused-import
from official.vision.serving import export_saved_model_lib
FLAGS = flags.FLAGS
flags.DEFINE_string(
'experiment', None, 'experiment type, e.g. retinanet_resnetfpn_coco')
flags.DEFINE_string('export_dir', None, 'The export directory.')
flags.DEFINE_string('checkpoint_path', None, 'Checkpoint path.')
flags.DEFINE_multi_string(
'config_file',
default=None,
    help='YAML/JSON files which specify overrides. The override order '
'follows the order of args. Note that each file '
'can be used as an override template to override the default parameters '
'specified in Python. If the same parameter is specified in both '
'`--config_file` and `--params_override`, `config_file` will be used '
'first, followed by params_override.')
flags.DEFINE_string(
'params_override', '',
    'The JSON/YAML file or string which specifies the parameters to be overridden'
' on top of `config_file` template.')
flags.DEFINE_integer(
'batch_size', None, 'The batch size.')
flags.DEFINE_string(
'input_type', 'image_tensor',
    'One of `image_tensor`, `image_bytes`, `tf_example`, or `tflite`.')
flags.DEFINE_string(
'input_image_size', '224,224',
'The comma-separated string of two integers representing the height,width '
'of the input to the model.')
flags.DEFINE_string('export_checkpoint_subdir', 'checkpoint',
'The subdirectory for checkpoints.')
flags.DEFINE_string('export_saved_model_subdir', 'saved_model',
'The subdirectory for saved model.')
def main(_):
params = exp_factory.get_exp_config(FLAGS.experiment)
for config_file in FLAGS.config_file or []:
params = hyperparams.override_params_dict(
params, config_file, is_strict=True)
if FLAGS.params_override:
params = hyperparams.override_params_dict(
params, FLAGS.params_override, is_strict=True)
params.validate()
params.lock()
export_saved_model_lib.export_inference_graph(
input_type=FLAGS.input_type,
batch_size=FLAGS.batch_size,
input_image_size=[int(x) for x in FLAGS.input_image_size.split(',')],
params=params,
checkpoint_path=FLAGS.checkpoint_path,
export_dir=FLAGS.export_dir,
export_checkpoint_subdir=FLAGS.export_checkpoint_subdir,
export_saved_model_subdir=FLAGS.export_saved_model_subdir)
if __name__ == '__main__':
app.run(main)
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
r"""Vision models export utility function for serving/inference."""
import os
from typing import Optional, List
from absl import logging
import tensorflow as tf
from official.core import config_definitions as cfg
from official.core import export_base
from official.core import train_utils
from official.vision import configs
from official.vision.serving import detection
from official.vision.serving import image_classification
from official.vision.serving import semantic_segmentation
from official.vision.serving import video_classification
def export_inference_graph(
input_type: str,
batch_size: Optional[int],
input_image_size: List[int],
params: cfg.ExperimentConfig,
checkpoint_path: str,
export_dir: str,
num_channels: Optional[int] = 3,
export_module: Optional[export_base.ExportModule] = None,
export_checkpoint_subdir: Optional[str] = None,
export_saved_model_subdir: Optional[str] = None,
save_options: Optional[tf.saved_model.SaveOptions] = None,
log_model_flops_and_params: bool = False):
"""Exports inference graph for the model specified in the exp config.
Saved model is stored at export_dir/saved_model, checkpoint is saved
at export_dir/checkpoint, and params is saved at export_dir/params.yaml.
Args:
input_type: One of `image_tensor`, `image_bytes`, `tf_example` or `tflite`.
    batch_size: An `int` batch size, or None.
input_image_size: List or Tuple of height and width.
params: Experiment params.
checkpoint_path: Trained checkpoint path or directory.
export_dir: Export directory path.
num_channels: The number of input image channels.
export_module: Optional export module to be used instead of using params
to create one. If None, the params will be used to create an export
module.
export_checkpoint_subdir: Optional subdirectory under export_dir
to store checkpoint.
export_saved_model_subdir: Optional subdirectory under export_dir
to store saved model.
save_options: `SaveOptions` for `tf.saved_model.save`.
log_model_flops_and_params: If True, writes model FLOPs to model_flops.txt
and model parameters to model_params.txt.
"""
if export_checkpoint_subdir:
output_checkpoint_directory = os.path.join(
export_dir, export_checkpoint_subdir)
else:
output_checkpoint_directory = None
if export_saved_model_subdir:
output_saved_model_directory = os.path.join(
export_dir, export_saved_model_subdir)
else:
output_saved_model_directory = export_dir
  # TODO(arashwan): Offer a direct path to use ExportModule with Task objects.
if not export_module:
if isinstance(params.task,
configs.image_classification.ImageClassificationTask):
export_module = image_classification.ClassificationModule(
params=params,
batch_size=batch_size,
input_image_size=input_image_size,
input_type=input_type,
num_channels=num_channels)
elif isinstance(params.task, configs.retinanet.RetinaNetTask) or isinstance(
params.task, configs.maskrcnn.MaskRCNNTask):
export_module = detection.DetectionModule(
params=params,
batch_size=batch_size,
input_image_size=input_image_size,
input_type=input_type,
num_channels=num_channels)
elif isinstance(params.task,
configs.semantic_segmentation.SemanticSegmentationTask):
export_module = semantic_segmentation.SegmentationModule(
params=params,
batch_size=batch_size,
input_image_size=input_image_size,
input_type=input_type,
num_channels=num_channels)
elif isinstance(params.task,
configs.video_classification.VideoClassificationTask):
export_module = video_classification.VideoClassificationModule(
params=params,
batch_size=batch_size,
input_image_size=input_image_size,
input_type=input_type,
num_channels=num_channels)
else:
raise ValueError('Export module not implemented for {} task.'.format(
type(params.task)))
export_base.export(
export_module,
function_keys=[input_type],
export_savedmodel_dir=output_saved_model_directory,
checkpoint_path=checkpoint_path,
timestamped=False,
save_options=save_options)
if output_checkpoint_directory:
ckpt = tf.train.Checkpoint(model=export_module.model)
ckpt.save(os.path.join(output_checkpoint_directory, 'ckpt'))
train_utils.serialize_config(params, export_dir)
if log_model_flops_and_params:
inputs_kwargs = None
if isinstance(
params.task,
(configs.retinanet.RetinaNetTask, configs.maskrcnn.MaskRCNNTask)):
# We need to create inputs_kwargs argument to specify the input shapes for
# subclass model that overrides model.call to take multiple inputs,
# e.g., RetinaNet model.
inputs_kwargs = {
'images':
tf.TensorSpec([1] + input_image_size + [num_channels],
tf.float32),
'image_shape':
tf.TensorSpec([1, 2], tf.float32)
}
dummy_inputs = {
k: tf.ones(v.shape.as_list(), tf.float32)
for k, v in inputs_kwargs.items()
}
# Must do forward pass to build the model.
export_module.model(**dummy_inputs)
else:
logging.info(
'Logging model flops and params not implemented for %s task.',
type(params.task))
return
train_utils.try_count_flops(export_module.model, inputs_kwargs,
os.path.join(export_dir, 'model_flops.txt'))
train_utils.write_model_params(export_module.model,
os.path.join(export_dir, 'model_params.txt'))
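# A hedged usage sketch (illustrative only): exporting a classification
# experiment with the function above. The checkpoint and export paths are
# placeholders.
def _example_export_inference_graph():  # pragma: no cover
  from official.core import exp_factory  # Local import for the example only.
  params = exp_factory.get_exp_config('resnet_imagenet')
  export_inference_graph(
      input_type='image_tensor',
      batch_size=1,
      input_image_size=[224, 224],
      params=params,
      checkpoint_path='/tmp/resnet/ckpt',
      export_dir='/tmp/resnet/export',
      export_checkpoint_subdir='checkpoint',
      export_saved_model_subdir='saved_model')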
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for official.core.export_saved_model_lib."""
import os
from unittest import mock
import tensorflow as tf
from official.core import export_base
from official.vision import configs
from official.vision.serving import export_saved_model_lib
class WriteModelFlopsAndParamsTest(tf.test.TestCase):
def setUp(self):
super().setUp()
self.tempdir = self.create_tempdir()
self.enter_context(
mock.patch.object(export_base, 'export', autospec=True, spec_set=True))
def _export_model_with_log_model_flops_and_params(self, params):
export_saved_model_lib.export_inference_graph(
input_type='image_tensor',
batch_size=1,
input_image_size=[64, 64],
params=params,
checkpoint_path=os.path.join(self.tempdir, 'unused-ckpt'),
export_dir=self.tempdir,
log_model_flops_and_params=True)
def assertModelAnalysisFilesExist(self):
self.assertTrue(
tf.io.gfile.exists(os.path.join(self.tempdir, 'model_params.txt')))
self.assertTrue(
tf.io.gfile.exists(os.path.join(self.tempdir, 'model_flops.txt')))
def test_retinanet_task(self):
params = configs.retinanet.retinanet_resnetfpn_coco()
params.task.model.backbone.resnet.model_id = 18
params.task.model.num_classes = 2
params.task.model.max_level = 6
self._export_model_with_log_model_flops_and_params(params)
self.assertModelAnalysisFilesExist()
def test_maskrcnn_task(self):
params = configs.maskrcnn.maskrcnn_resnetfpn_coco()
params.task.model.backbone.resnet.model_id = 18
params.task.model.num_classes = 2
params.task.model.max_level = 6
self._export_model_with_log_model_flops_and_params(params)
self.assertModelAnalysisFilesExist()
if __name__ == '__main__':
tf.test.main()
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Vision models export utility function for serving/inference."""
import os
from typing import Optional, List
import tensorflow as tf
from official.core import config_definitions as cfg
from official.core import export_base
from official.core import train_utils
from official.vision.serving import export_module_factory
def export(
input_type: str,
batch_size: Optional[int],
input_image_size: List[int],
params: cfg.ExperimentConfig,
checkpoint_path: str,
export_dir: str,
num_channels: Optional[int] = 3,
export_module: Optional[export_base.ExportModule] = None,
export_checkpoint_subdir: Optional[str] = None,
export_saved_model_subdir: Optional[str] = None,
save_options: Optional[tf.saved_model.SaveOptions] = None):
"""Exports the model specified in the exp config.
Saved model is stored at export_dir/saved_model, checkpoint is saved
at export_dir/checkpoint, and params is saved at export_dir/params.yaml.
Args:
input_type: One of `image_tensor`, `image_bytes`, `tf_example`.
    batch_size: An `int` batch size, or None.
input_image_size: List or Tuple of height and width.
params: Experiment params.
checkpoint_path: Trained checkpoint path or directory.
export_dir: Export directory path.
num_channels: The number of input image channels.
export_module: Optional export module to be used instead of using params
to create one. If None, the params will be used to create an export
module.
export_checkpoint_subdir: Optional subdirectory under export_dir
to store checkpoint.
export_saved_model_subdir: Optional subdirectory under export_dir
to store saved model.
save_options: `SaveOptions` for `tf.saved_model.save`.
"""
if export_checkpoint_subdir:
output_checkpoint_directory = os.path.join(
export_dir, export_checkpoint_subdir)
else:
output_checkpoint_directory = None
if export_saved_model_subdir:
output_saved_model_directory = os.path.join(
export_dir, export_saved_model_subdir)
else:
output_saved_model_directory = export_dir
  if export_module is None:
    export_module = export_module_factory.get_export_module(
        params,
        input_type=input_type,
        batch_size=batch_size,
        input_image_size=input_image_size,
        num_channels=num_channels)
export_base.export(
export_module,
function_keys=[input_type],
export_savedmodel_dir=output_saved_model_directory,
checkpoint_path=checkpoint_path,
timestamped=False,
save_options=save_options)
if output_checkpoint_directory:
ckpt = tf.train.Checkpoint(model=export_module.model)
ckpt.save(os.path.join(output_checkpoint_directory, 'ckpt'))
train_utils.serialize_config(params, export_dir)
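# A hedged usage sketch (illustrative only): passing a prebuilt export module
# so that `export` skips building one from `params`. Paths are placeholders.
def _example_export_with_prebuilt_module(params, checkpoint_path, export_dir):  # pragma: no cover
  module = export_module_factory.get_export_module(
      params,
      input_type='image_tensor',
      batch_size=1,
      input_image_size=[224, 224])
  export(
      input_type='image_tensor',
      batch_size=1,
      input_image_size=[224, 224],
      params=params,
      checkpoint_path=checkpoint_path,
      export_dir=export_dir,
      export_module=module)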
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""A script to export the image classification as a TF-Hub SavedModel."""
# Import libraries
from absl import app
from absl import flags
import tensorflow as tf
from official.core import exp_factory
from official.modeling import hyperparams
from official.vision import registry_imports # pylint: disable=unused-import
from official.vision.modeling import factory
FLAGS = flags.FLAGS
flags.DEFINE_string(
'experiment', None, 'experiment type, e.g. resnet_imagenet')
flags.DEFINE_string(
'checkpoint_path', None, 'Checkpoint path.')
flags.DEFINE_string(
'export_path', None, 'The export directory.')
flags.DEFINE_multi_string(
'config_file',
None,
    'YAML/JSON files which specify overrides. The override order '
'follows the order of args. Note that each file '
'can be used as an override template to override the default parameters '
'specified in Python. If the same parameter is specified in both '
'`--config_file` and `--params_override`, `config_file` will be used '
'first, followed by params_override.')
flags.DEFINE_string(
'params_override', '',
    'The JSON/YAML file or string which specifies the parameters to be overridden'
' on top of `config_file` template.')
flags.DEFINE_integer(
'batch_size', None, 'The batch size.')
flags.DEFINE_string(
'input_image_size',
'224,224',
'The comma-separated string of two integers representing the height,width '
'of the input to the model.')
flags.DEFINE_boolean(
'skip_logits_layer',
False,
'Whether to skip the prediction layer and only output the feature vector.')
def export_model_to_tfhub(params,
batch_size,
input_image_size,
skip_logits_layer,
checkpoint_path,
export_path):
"""Export an image classification model to TF-Hub."""
input_specs = tf.keras.layers.InputSpec(shape=[batch_size] +
input_image_size + [3])
model = factory.build_classification_model(
input_specs=input_specs,
model_config=params.task.model,
l2_regularizer=None,
skip_logits_layer=skip_logits_layer)
checkpoint = tf.train.Checkpoint(model=model)
checkpoint.restore(checkpoint_path).assert_existing_objects_matched()
model.save(export_path, include_optimizer=False, save_format='tf')
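# A hedged consumption sketch (illustrative only): wrapping the exported
# SavedModel as a Keras layer. It assumes the `tensorflow_hub` package is
# installed and that `export_path` points at the directory written above.
def _example_wrap_exported_model(export_path, input_image_size):  # pragma: no cover
  import tensorflow_hub as hub  # Local import; assumed to be available.
  encoder = hub.KerasLayer(export_path, trainable=False)
  inputs = tf.keras.Input(shape=input_image_size + [3], dtype=tf.float32)
  return tf.keras.Model(inputs=inputs, outputs=encoder(inputs))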
def main(_):
params = exp_factory.get_exp_config(FLAGS.experiment)
for config_file in FLAGS.config_file or []:
params = hyperparams.override_params_dict(
params, config_file, is_strict=True)
if FLAGS.params_override:
params = hyperparams.override_params_dict(
params, FLAGS.params_override, is_strict=True)
params.validate()
params.lock()
export_model_to_tfhub(
params=params,
batch_size=FLAGS.batch_size,
input_image_size=[int(x) for x in FLAGS.input_image_size.split(',')],
skip_logits_layer=FLAGS.skip_logits_layer,
checkpoint_path=FLAGS.checkpoint_path,
export_path=FLAGS.export_path)
if __name__ == '__main__':
app.run(main)
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Binary to convert a saved model to tflite model.
It requires a SavedModel exported with export_saved_model.py using batch size 1
and input type `tflite`, together with the same config file used to export the
SavedModel. It supports optional post-training quantization. When integer
quantization is used, calibration steps must be provided to calibrate the model
input.
To convert a SavedModel to a TFLite model:
EXPERIMENT_TYPE = XX
TFLITE_PATH = XX
SAVED_MODEL_DIR = XX
CONFIG_FILE = XX
export_tflite --experiment=${EXPERIMENT_TYPE} \
  --saved_model_dir=${SAVED_MODEL_DIR} \
--tflite_path=${TFLITE_PATH} \
--config_file=${CONFIG_FILE} \
--quant_type=fp16 \
--calibration_steps=500
"""
from absl import app
from absl import flags
from absl import logging
import tensorflow as tf
from official.core import exp_factory
from official.modeling import hyperparams
from official.vision import registry_imports # pylint: disable=unused-import
from official.vision.serving import export_tflite_lib
FLAGS = flags.FLAGS
flags.DEFINE_string(
'experiment',
None,
'experiment type, e.g. retinanet_resnetfpn_coco',
required=True)
flags.DEFINE_multi_string(
'config_file',
default='',
    help='YAML/JSON files which specify overrides. The override order '
'follows the order of args. Note that each file '
'can be used as an override template to override the default parameters '
'specified in Python. If the same parameter is specified in both '
'`--config_file` and `--params_override`, `config_file` will be used '
'first, followed by params_override.')
flags.DEFINE_string(
'params_override', '',
    'The JSON/YAML file or string which specifies the parameters to be overridden'
' on top of `config_file` template.')
flags.DEFINE_string(
'saved_model_dir', None, 'The directory to the saved model.', required=True)
flags.DEFINE_string(
'tflite_path', None, 'The path to the output tflite model.', required=True)
flags.DEFINE_string(
'quant_type',
default=None,
    help='Post-training quantization type. Supports `int8`, `int8_full`, '
'`fp16`, and `default`. See '
'https://www.tensorflow.org/lite/performance/post_training_quantization '
'for more details.')
flags.DEFINE_integer('calibration_steps', 500,
                     'The number of calibration steps for the integer-quantized model.')
def main(_) -> None:
params = exp_factory.get_exp_config(FLAGS.experiment)
if FLAGS.config_file is not None:
for config_file in FLAGS.config_file:
params = hyperparams.override_params_dict(
params, config_file, is_strict=True)
if FLAGS.params_override:
params = hyperparams.override_params_dict(
params, FLAGS.params_override, is_strict=True)
params.validate()
params.lock()
logging.info('Converting SavedModel from %s to TFLite model...',
FLAGS.saved_model_dir)
tflite_model = export_tflite_lib.convert_tflite_model(
saved_model_dir=FLAGS.saved_model_dir,
quant_type=FLAGS.quant_type,
params=params,
calibration_steps=FLAGS.calibration_steps)
with tf.io.gfile.GFile(FLAGS.tflite_path, 'wb') as fw:
fw.write(tflite_model)
logging.info('TFLite model converted and saved to %s.', FLAGS.tflite_path)
if __name__ == '__main__':
app.run(main)
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Library to facilitate TFLite model conversion."""
import functools
from typing import Iterator, List, Optional
from absl import logging
import tensorflow as tf
from official.core import config_definitions as cfg
from official.vision import configs
from official.vision import tasks
def create_representative_dataset(
params: cfg.ExperimentConfig) -> tf.data.Dataset:
"""Creates a tf.data.Dataset to load images for representative dataset.
Args:
params: An ExperimentConfig.
Returns:
A tf.data.Dataset instance.
Raises:
ValueError: If task is not supported.
"""
if isinstance(params.task,
configs.image_classification.ImageClassificationTask):
task = tasks.image_classification.ImageClassificationTask(params.task)
elif isinstance(params.task, configs.retinanet.RetinaNetTask):
task = tasks.retinanet.RetinaNetTask(params.task)
elif isinstance(params.task, configs.maskrcnn.MaskRCNNTask):
task = tasks.maskrcnn.MaskRCNNTask(params.task)
elif isinstance(params.task,
configs.semantic_segmentation.SemanticSegmentationTask):
task = tasks.semantic_segmentation.SemanticSegmentationTask(params.task)
else:
raise ValueError('Task {} not supported.'.format(type(params.task)))
# Ensure batch size is 1 for TFLite model.
params.task.train_data.global_batch_size = 1
params.task.train_data.dtype = 'float32'
logging.info('Task config: %s', params.task.as_dict())
return task.build_inputs(params=params.task.train_data)
def representative_dataset(
params: cfg.ExperimentConfig,
calibration_steps: int = 2000) -> Iterator[List[tf.Tensor]]:
""""Creates representative dataset for input calibration.
Args:
params: An ExperimentConfig.
calibration_steps: The steps to do calibration.
Yields:
An input image tensor.
"""
dataset = create_representative_dataset(params=params)
for image, _ in dataset.take(calibration_steps):
# Skip images that do not have 3 channels.
if image.shape[-1] != 3:
continue
yield [image]
def convert_tflite_model(saved_model_dir: str,
quant_type: Optional[str] = None,
params: Optional[cfg.ExperimentConfig] = None,
calibration_steps: Optional[int] = 2000) -> bytes:
"""Converts and returns a TFLite model.
Args:
saved_model_dir: The directory to the SavedModel.
quant_type: The post training quantization (PTQ) method. It can be one of
      `default` (dynamic range), `fp16` (float16), `int8` (integer with float
fallback), `int8_full` (integer only) and None (no quantization).
params: An optional ExperimentConfig to load and preprocess input images to
do calibration for integer quantization.
calibration_steps: The steps to do calibration.
Returns:
A converted TFLite model with optional PTQ.
Raises:
    ValueError: If `quant_type` is not supported.
"""
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
if quant_type:
if quant_type.startswith('int8'):
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = functools.partial(
representative_dataset,
params=params,
calibration_steps=calibration_steps)
if quant_type == 'int8_full':
converter.target_spec.supported_ops = [
tf.lite.OpsSet.TFLITE_BUILTINS_INT8
]
converter.inference_input_type = tf.uint8 # or tf.int8
converter.inference_output_type = tf.uint8 # or tf.int8
elif quant_type == 'fp16':
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_types = [tf.float16]
elif quant_type == 'default':
converter.optimizations = [tf.lite.Optimize.DEFAULT]
elif quant_type == 'qat':
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.inference_input_type = tf.uint8 # or tf.int8
converter.inference_output_type = tf.uint8 # or tf.int8
else:
raise ValueError(f'quantization type {quant_type} is not supported.')
return converter.convert()
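# A hedged usage sketch (illustrative only): running the converted model bytes
# with the TFLite interpreter. The zero-filled input matches whatever shape and
# dtype the exported signature declares, so it also covers integer-quantized
# models.
def _example_run_tflite_model(tflite_model: bytes):  # pragma: no cover
  import numpy as np  # Local import for the example only.
  interpreter = tf.lite.Interpreter(model_content=tflite_model)
  interpreter.allocate_tensors()
  input_details = interpreter.get_input_details()[0]
  output_details = interpreter.get_output_details()[0]
  dummy_input = np.zeros(input_details['shape'], dtype=input_details['dtype'])
  interpreter.set_tensor(input_details['index'], dummy_input)
  interpreter.invoke()
  return interpreter.get_tensor(output_details['index'])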
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for export_tflite_lib."""
import os
from absl.testing import parameterized
import tensorflow as tf
from tensorflow.python.distribute import combinations
from official.core import exp_factory
from official.vision import registry_imports # pylint: disable=unused-import
from official.vision.dataloaders import tfexample_utils
from official.vision.serving import detection as detection_serving
from official.vision.serving import export_tflite_lib
from official.vision.serving import image_classification as image_classification_serving
from official.vision.serving import semantic_segmentation as semantic_segmentation_serving
class ExportTfliteLibTest(tf.test.TestCase, parameterized.TestCase):
def _create_test_tfrecord(self, tfrecord_file, example, num_samples):
examples = [example] * num_samples
tfexample_utils.dump_to_tfrecord(
record_file=tfrecord_file, tf_examples=examples)
def _export_from_module(self, module, input_type, saved_model_dir):
signatures = module.get_inference_signatures(
{input_type: 'serving_default'})
tf.saved_model.save(module, saved_model_dir, signatures=signatures)
@combinations.generate(
combinations.combine(
experiment=['mobilenet_imagenet'],
quant_type=[None, 'default', 'fp16', 'int8', 'int8_full'],
input_image_size=[[224, 224]]))
def test_export_tflite_image_classification(self, experiment, quant_type,
input_image_size):
test_tfrecord_file = os.path.join(self.get_temp_dir(), 'cls_test.tfrecord')
example = tf.train.Example.FromString(
tfexample_utils.create_classification_example(
image_height=input_image_size[0], image_width=input_image_size[1]))
self._create_test_tfrecord(
tfrecord_file=test_tfrecord_file, example=example, num_samples=10)
params = exp_factory.get_exp_config(experiment)
params.task.validation_data.input_path = test_tfrecord_file
params.task.train_data.input_path = test_tfrecord_file
temp_dir = self.get_temp_dir()
module = image_classification_serving.ClassificationModule(
params=params,
batch_size=1,
input_image_size=input_image_size,
input_type='tflite')
self._export_from_module(
module=module,
input_type='tflite',
saved_model_dir=os.path.join(temp_dir, 'saved_model'))
tflite_model = export_tflite_lib.convert_tflite_model(
saved_model_dir=os.path.join(temp_dir, 'saved_model'),
quant_type=quant_type,
params=params,
calibration_steps=5)
self.assertIsInstance(tflite_model, bytes)
@combinations.generate(
combinations.combine(
experiment=['retinanet_mobile_coco'],
quant_type=[None, 'default', 'fp16'],
input_image_size=[[384, 384]]))
def test_export_tflite_detection(self, experiment, quant_type,
input_image_size):
test_tfrecord_file = os.path.join(self.get_temp_dir(), 'det_test.tfrecord')
example = tfexample_utils.create_detection_test_example(
image_height=input_image_size[0],
image_width=input_image_size[1],
image_channel=3,
num_instances=10)
self._create_test_tfrecord(
tfrecord_file=test_tfrecord_file, example=example, num_samples=10)
params = exp_factory.get_exp_config(experiment)
params.task.validation_data.input_path = test_tfrecord_file
params.task.train_data.input_path = test_tfrecord_file
temp_dir = self.get_temp_dir()
module = detection_serving.DetectionModule(
params=params,
batch_size=1,
input_image_size=input_image_size,
input_type='tflite')
self._export_from_module(
module=module,
input_type='tflite',
saved_model_dir=os.path.join(temp_dir, 'saved_model'))
tflite_model = export_tflite_lib.convert_tflite_model(
saved_model_dir=os.path.join(temp_dir, 'saved_model'),
quant_type=quant_type,
params=params,
calibration_steps=5)
self.assertIsInstance(tflite_model, bytes)
@combinations.generate(
combinations.combine(
experiment=['mnv2_deeplabv3_pascal'],
quant_type=[None, 'default', 'fp16', 'int8', 'int8_full'],
input_image_size=[[512, 512]]))
def test_export_tflite_semantic_segmentation(self, experiment, quant_type,
input_image_size):
test_tfrecord_file = os.path.join(self.get_temp_dir(), 'seg_test.tfrecord')
example = tfexample_utils.create_segmentation_test_example(
image_height=input_image_size[0],
image_width=input_image_size[1],
image_channel=3)
self._create_test_tfrecord(
tfrecord_file=test_tfrecord_file, example=example, num_samples=10)
params = exp_factory.get_exp_config(experiment)
params.task.validation_data.input_path = test_tfrecord_file
params.task.train_data.input_path = test_tfrecord_file
temp_dir = self.get_temp_dir()
module = semantic_segmentation_serving.SegmentationModule(
params=params,
batch_size=1,
input_image_size=input_image_size,
input_type='tflite')
self._export_from_module(
module=module,
input_type='tflite',
saved_model_dir=os.path.join(temp_dir, 'saved_model'))
tflite_model = export_tflite_lib.convert_tflite_model(
saved_model_dir=os.path.join(temp_dir, 'saved_model'),
quant_type=quant_type,
params=params,
calibration_steps=5)
self.assertIsInstance(tflite_model, bytes)
if __name__ == '__main__':
tf.test.main()