Commit 1a3c83d6 authored by zhanggzh

Add keras-cv models and training code

parent 9846958a
# Copyright 2022 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"Argmax-based box matching"
from typing import List
from typing import Tuple
import tensorflow as tf
class ArgmaxBoxMatcher:
"""Box matching logic based on argmax of highest value (e.g., IOU).
This class computes matches from a similarity matrix. Each row will be
matched to at least one column, the matched result can either be positive
/ negative, or simply ignored depending on the setting.
The settings include `thresholds` and `match_values`. For example, if:
1) thresholds=[negative_threshold, positive_threshold], and
match_values=[negative_value=0, ignore_value=-1, positive_value=1]: a row
is assigned positive_value if its argmax result >= positive_threshold,
negative_value if its argmax result < negative_threshold, and
ignore_value if its argmax result is within
[negative_threshold, positive_threshold).
2) thresholds=[negative_threshold, positive_threshold], and
match_values=[ignore_value=-1, negative_value=0, positive_value=1]: a row
is assigned positive_value if its argmax result >= positive_threshold,
ignore_value if its argmax result < negative_threshold, and
negative_value if its argmax result is within
[negative_threshold, positive_threshold). This differs from case 1) only
by swapping the first two match values.
3) thresholds=[positive_threshold], and
match_values=[negative_value, positive_value]: a row is assigned
positive_value if its argmax result >= positive_threshold, and
negative_value otherwise.
Args:
thresholds: A sorted list of floats used to classify the matches into
different results (e.g. positive or negative or ignored match). The
list will be prepended with -Inf and appended with +Inf.
match_values: A list of integers representing match results (e.g.
positive or negative or ignored match). len(`match_values`) must
equal len(`thresholds`) + 1.
force_match_for_each_col: each row will be argmax-matched to at
least one column. This means some columns will be matched to
multiple rows while some columns will not be matched to any rows.
Filtering by `thresholds` will cause fewer columns to be matched
with a positive result. Setting this to True guarantees that each
column will be matched with a positive result to at least one row.
Raises:
ValueError: if `thresholds` is not sorted, or if
len(`match_values`) != len(`thresholds`) + 1.
Usage:
```python
box_matcher = keras_cv.ops.ArgmaxBoxMatcher([0.3, 0.7], [-1, 0, 1])
iou_metric = keras_cv.bounding_box.compute_iou(anchors, gt_boxes)
matched_columns, matched_match_values = box_matcher(iou_metric)
cls_mask = tf.less_equal(matched_match_values, 0)
```
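A minimal sketch of the bucketing described above, using the case 1
settings (thresholds=[0.3, 0.7], match_values=[negative=0, ignore=-1,
positive=1]); the expected outputs are illustrative:
```python
similarity = tf.constant([[0.1, 0.2], [0.45, 0.3], [0.8, 0.75]])
matcher = keras_cv.ops.ArgmaxBoxMatcher([0.3, 0.7], [0, -1, 1])
matched_columns, matched_values = matcher(similarity)
# matched_columns == [1, 0, 0]
# matched_values == [0, -1, 1]  (negative, ignored, positive)
```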
TODO(tanzhenyu): document when to use which mode.
"""
def __init__(
self,
thresholds: List[float],
match_values: List[int],
force_match_for_each_col: bool = False,
):
if sorted(thresholds) != thresholds:
raise ValueError(f"`threshold` must be sorted, got {thresholds}")
self.match_values = match_values
if len(match_values) != len(thresholds) + 1:
raise ValueError(
f"len(`match_values`) must be len(`thresholds`) + 1, got "
f"match_values {match_values}, thresholds {thresholds}"
)
# Copy the list so the caller's `thresholds` is not mutated.
thresholds = list(thresholds)
thresholds.insert(0, -float("inf"))
thresholds.append(float("inf"))
self.thresholds = thresholds
self.force_match_for_each_col = force_match_for_each_col
def __call__(self, similarity_matrix: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
"""Matches each row to a column based on argmax
TODO(tanzhenyu): consider swapping rows and cols.
Args:
similarity_matrix: A float Tensor of shape [num_rows, num_cols] or
[batch_size, num_rows, num_cols] representing any similarity metric.
Returns:
matched_columns: An integer tensor of shape [num_rows] or [batch_size,
num_rows] storing the index of the matched column for each row.
matched_values: An integer tensor of shape [num_rows] or [batch_size,
num_rows] storing the match result (positive match, negative match,
ignored match).
"""
squeeze_result = False
if len(similarity_matrix.shape) == 2:
squeeze_result = True
similarity_matrix = tf.expand_dims(similarity_matrix, axis=0)
static_shape = similarity_matrix.shape.as_list()
num_rows = static_shape[1] or tf.shape(similarity_matrix)[1]
batch_size = static_shape[0] or tf.shape(similarity_matrix)[0]
def _match_when_cols_are_empty():
"""Performs matching when the rows of similarity matrix are empty.
When the rows are empty, all detections are false positives. So we return
a tensor of -1's to indicate that the rows do not match to any columns.
Returns:
matched_columns: An integer tensor of shape [batch_size, num_rows]
storing the index of the matched column for each row.
matched_values: An integer tensor of shape [batch_size, num_rows]
storing the match type indicator (e.g. positive or negative
or ignored match).
"""
with tf.name_scope("empty_gt_boxes"):
matched_columns = tf.zeros([batch_size, num_rows], dtype=tf.int32)
matched_values = -tf.ones([batch_size, num_rows], dtype=tf.int32)
return matched_columns, matched_values
def _match_when_cols_are_non_empty():
"""Performs matching when the rows of similarity matrix are non empty.
Returns:
matched_columns: An integer tensor of shape [batch_size, num_rows]
storing the index of the matched column for each row.
matched_values: An integer tensor of shape [batch_size, num_rows]
storing the match type indicator (e.g. positive or negative
or ignored match).
"""
with tf.name_scope("non_empty_gt_boxes"):
matched_columns = tf.argmax(
similarity_matrix, axis=-1, output_type=tf.int32
)
# Get the best (max) similarity value for each row.
matched_vals = tf.reduce_max(similarity_matrix, axis=-1)
matched_values = tf.zeros([batch_size, num_rows], tf.int32)
match_dtype = matched_vals.dtype
for (ind, low, high) in zip(
self.match_values, self.thresholds[:-1], self.thresholds[1:]
):
low_threshold = tf.cast(low, match_dtype)
high_threshold = tf.cast(high, match_dtype)
mask = tf.logical_and(
tf.greater_equal(matched_vals, low_threshold),
tf.less(matched_vals, high_threshold),
)
matched_values = self._set_values_using_indicator(
matched_values, mask, ind
)
if self.force_match_for_each_col:
# [batch_size, num_cols], for each column (groundtruth_box), find the
# best matching row (anchor).
matching_rows = tf.argmax(
input=similarity_matrix, axis=1, output_type=tf.int32
)
# [batch_size, num_cols, num_rows], a transposed 0-1 mapping matrix M,
# where M[j, i] = 1 means column j is matched to row i.
column_to_row_match_mapping = tf.one_hot(
matching_rows, depth=num_rows
)
# [batch_size, num_rows], for each row (anchor), find the matched
# column (groundtruth_box).
force_matched_columns = tf.argmax(
input=column_to_row_match_mapping, axis=1, output_type=tf.int32
)
# [batch_size, num_rows]
force_matched_column_mask = tf.cast(
tf.reduce_max(column_to_row_match_mapping, axis=1), tf.bool
)
# [batch_size, num_rows]
matched_columns = tf.where(
force_matched_column_mask,
force_matched_columns,
matched_columns,
)
matched_values = tf.where(
force_matched_column_mask,
self.match_values[-1]
* tf.ones([batch_size, num_rows], dtype=tf.int32),
matched_values,
)
return matched_columns, matched_values
num_gt_boxes = (
similarity_matrix.shape.as_list()[-1] or tf.shape(similarity_matrix)[-1]
)
matched_columns, matched_values = tf.cond(
pred=tf.greater(num_gt_boxes, 0),
true_fn=_match_when_cols_are_non_empty,
false_fn=_match_when_cols_are_empty,
)
if squeeze_result:
matched_columns = tf.squeeze(matched_columns, axis=0)
matched_values = tf.squeeze(matched_values, axis=0)
return matched_columns, matched_values
def _set_values_using_indicator(self, x, indicator, val):
"""Set the indicated fields of x to val.
Args:
x: tensor.
indicator: boolean with same shape as x.
val: scalar with value to set.
Returns:
modified tensor.
"""
indicator = tf.cast(indicator, x.dtype)
return tf.add(tf.multiply(x, 1 - indicator), val * indicator)
def get_config(self):
config = {
"thresholds": self.thresholds[1:-1],
"match_values": self.match_values,
"force_match_for_each_col": self.force_match_for_each_col,
}
return config
# Copyright 2022 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
from keras_cv.ops.box_matcher import ArgmaxBoxMatcher
class ArgmaxBoxMatcherTest(tf.test.TestCase):
def test_box_matcher_invalid_length(self):
fg_threshold = 0.5
bg_thresh_hi = 0.2
bg_thresh_lo = 0.0
with self.assertRaisesRegex(ValueError, "must be len"):
_ = ArgmaxBoxMatcher(
thresholds=[bg_thresh_lo, bg_thresh_hi, fg_threshold],
match_values=[-3, -2, -1],
)
def test_box_matcher_unsorted_thresholds(self):
fg_threshold = 0.5
bg_thresh_hi = 0.2
bg_thresh_lo = 0.0
with self.assertRaisesRegex(ValueError, "must be sorted"):
_ = ArgmaxBoxMatcher(
thresholds=[bg_thresh_hi, bg_thresh_lo, fg_threshold],
match_values=[-3, -2, -1, 1],
)
def test_box_matcher_unbatched(self):
sim_matrix = tf.constant([[0.04, 0, 0, 0], [0, 0, 1.0, 0]], dtype=tf.float32)
fg_threshold = 0.5
bg_thresh_hi = 0.2
bg_thresh_lo = 0.0
matcher = ArgmaxBoxMatcher(
thresholds=[bg_thresh_lo, bg_thresh_hi, fg_threshold],
match_values=[-3, -2, -1, 1],
)
match_indices, matched_values = matcher(sim_matrix)
positive_matches = tf.greater_equal(matched_values, 0)
negative_matches = tf.equal(matched_values, -2)
self.assertAllEqual(positive_matches.numpy(), [False, True])
self.assertAllEqual(negative_matches.numpy(), [True, False])
self.assertAllEqual(match_indices.numpy(), [0, 2])
self.assertAllEqual(matched_values.numpy(), [-2, 1])
def test_box_matcher_batched(self):
sim_matrix = tf.constant([[[0.04, 0, 0, 0], [0, 0, 1.0, 0]]], dtype=tf.float32)
fg_threshold = 0.5
bg_thresh_hi = 0.2
bg_thresh_lo = 0.0
matcher = ArgmaxBoxMatcher(
thresholds=[bg_thresh_lo, bg_thresh_hi, fg_threshold],
match_values=[-3, -2, -1, 1],
)
match_indices, matched_values = matcher(sim_matrix)
positive_matches = tf.greater_equal(matched_values, 0)
negative_matches = tf.equal(matched_values, -2)
self.assertAllEqual(positive_matches.numpy(), [[False, True]])
self.assertAllEqual(negative_matches.numpy(), [[True, False]])
self.assertAllEqual(match_indices.numpy(), [[0, 2]])
self.assertAllEqual(matched_values.numpy(), [[-2, 1]])
def test_box_matcher_force_match(self):
sim_matrix = tf.constant(
[[0, 0.04, 0, 0.1], [0, 0, 1.0, 0], [0.1, 0, 0, 0], [0, 0, 0, 0.6]],
dtype=tf.float32,
)
fg_threshold = 0.5
bg_thresh_hi = 0.2
bg_thresh_lo = 0.0
matcher = ArgmaxBoxMatcher(
thresholds=[bg_thresh_lo, bg_thresh_hi, fg_threshold],
match_values=[-3, -2, -1, 1],
force_match_for_each_col=True,
)
match_indices, matched_values = matcher(sim_matrix)
positive_matches = tf.greater_equal(matched_values, 0)
negative_matches = tf.equal(matched_values, -2)
self.assertAllEqual(positive_matches.numpy(), [True, True, True, True])
self.assertAllEqual(negative_matches.numpy(), [False, False, False, False])
# The first anchor cannot be matched to the 4th gt box given that it is
# matched to the last anchor.
self.assertAllEqual(match_indices.numpy(), [1, 2, 0, 3])
self.assertAllEqual(matched_values.numpy(), [1, 1, 1, 1])
def test_box_matcher_empty_gt_boxes(self):
sim_matrix = tf.constant([[], []], dtype=tf.float32)
fg_threshold = 0.5
bg_thresh_hi = 0.2
bg_thresh_lo = 0.0
matcher = ArgmaxBoxMatcher(
thresholds=[bg_thresh_lo, bg_thresh_hi, fg_threshold],
match_values=[-3, -2, -1, 1],
)
match_indices, matched_values = matcher(sim_matrix)
positive_matches = tf.greater_equal(matched_values, 0)
ignore_matches = tf.equal(matched_values, -1)
self.assertAllEqual(positive_matches.numpy(), [False, False])
self.assertAllEqual(ignore_matches.numpy(), [True, True])
self.assertAllEqual(match_indices.numpy(), [0, 0])
self.assertAllEqual(matched_values.numpy(), [-1, -1])
# Copyright 2022 The KerasCV Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""IoU3D using a custom TF op."""
from tensorflow.python.framework import load_library
from tensorflow.python.platform import resource_loader
class IoU3D:
"""Implements IoU computation for 3D upright rotated bounding boxes.
Note that this is implemented using a custom TensorFlow op. Initializing an
IoU3D object will attempt to load the binary for that op.
Boxes should have the format [center_x, center_y, center_z, dimension_x,
dimension_y, dimension_z, heading (in radians)].
Sample Usage:
```python
y_true = [[0, 0, 0, 2, 2, 2, 0], [1, 1, 1, 2, 2, 2, 3 * math.pi / 4]]
y_pred = [[1, 1, 1, 2, 2, 2, math.pi / 4], [1, 1, 1, 2, 2, 2, 0]]
iou = IoU3D()
iou(y_true, y_pred)
```
"""
def __init__(self):
pairwise_iou_op = load_library.load_op_library(
resource_loader.get_path_to_datafile(
"../custom_ops/_keras_cv_custom_ops.so"
)
)
self.iou_3d = pairwise_iou_op.pairwise_iou3d
def __call__(self, y_true, y_pred):
return self.iou_3d(y_true, y_pred)
# Copyright 2022 The KerasCV Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Tests for IoU3D using custom op."""
import math
import os
import pytest
import tensorflow as tf
from keras_cv.ops import IoU3D
class IoU3DTest(tf.test.TestCase):
@pytest.mark.skipif(
"TEST_CUSTOM_OPS" not in os.environ or os.environ["TEST_CUSTOM_OPS"] != "true",
reason="Requires binaries compiled from source",
)
def testOpCall(self):
# Predicted boxes:
# 0: a 2x2x2 box centered at 0,0,0, rotated 0 degrees
# 1: a 2x2x2 box centered at 1,1,1, rotated 135 degrees
# Ground Truth boxes:
# 0: a 2x2x2 box centered at 1,1,1, rotated 45 degrees (identical to predicted box 1)
# 1: a 2x2x2 box centered at 1,1,1, rotated 0 degrees
box_preds = [[0, 0, 0, 2, 2, 2, 0], [1, 1, 1, 2, 2, 2, 3 * math.pi / 4]]
box_gt = [[1, 1, 1, 2, 2, 2, math.pi / 4], [1, 1, 1, 2, 2, 2, 0]]
# Predicted box 0 and both ground truth boxes overlap by 1/8th of the box
# volume. Therefore, IoU is 1/15.
# Predicted box 1 is the same as ground truth box 0, so their IoU is 1.
# Predicted box 1 shares a center with ground truth box 1, but is rotated by
# 135 degrees. Their IoU reduces to that of two overlapping squares that
# share a center and are offset by 135 degrees, which is the square root of
# 0.5.
expected_ious = [[1 / 15, 1 / 15], [1, 0.5**0.5]]
iou_3d = IoU3D()
self.assertAllClose(iou_3d(box_preds, box_gt), expected_ious)
if __name__ == "__main__":
tf.test.main()
# Copyright 2022 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
def get_rank(tensor):
return tensor.shape.ndims or tf.rank(tensor)
def _get_3d_rotation_matrix(yaw, roll, pitch):
"""Creates 3x3 rotation matrix from yaw, roll, pitch (angles in radians).
Note: Yaw -> Z, Roll -> X, Pitch -> Y
Args:
yaw: float tensor representing a yaw angle in radians.
roll: float tensor representing a roll angle in radians.
pitch: float tensor representing a pitch angle in radians.
Returns:
A [3, 3] tensor corresponding to a rotation matrix.
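As a quick sanity check of the convention (illustrative only), a pure yaw
of pi/2 should rotate the x-axis onto the y-axis:
```python
import math
rotation = _get_3d_rotation_matrix(
    tf.constant(math.pi / 2), tf.constant(0.0), tf.constant(0.0)
)
tf.linalg.matvec(rotation, tf.constant([1.0, 0.0, 0.0]))  # ~[0., 1., 0.]
```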
"""
def _UnitX(angle):
return tf.reshape(
[
1.0,
0.0,
0.0,
0.0,
tf.cos(angle),
-tf.sin(angle),
0.0,
tf.sin(angle),
tf.cos(angle),
],
shape=[3, 3],
)
def _UnitY(angle):
return tf.reshape(
[
tf.cos(angle),
0.0,
tf.sin(angle),
0.0,
1.0,
0.0,
-tf.sin(angle),
0.0,
tf.cos(angle),
],
shape=[3, 3],
)
def _UnitZ(angle):
return tf.reshape(
[
tf.cos(angle),
-tf.sin(angle),
0.0,
tf.sin(angle),
tf.cos(angle),
0.0,
0.0,
0.0,
1.0,
],
shape=[3, 3],
)
return tf.matmul(tf.matmul(_UnitZ(yaw), _UnitX(roll)), _UnitY(pitch))
def _center_xyzWHD_to_corner_xyz(boxes):
"""convert from center format to corner format.
Args:
boxes: [..., num_boxes, 7] float32 Tensor for 3d boxes in [x, y, z, dx,
dy, dz, phi].
Returns:
corners: [..., num_boxes, 8, 3] float32 Tensor for 3d corners in [x, y, z].
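Usage (a minimal illustrative sketch):
```python
box = tf.constant([[0.0, 0.0, 0.0, 2.0, 2.0, 2.0, 0.0]])  # axis-aligned cube
corners = _center_xyzWHD_to_corner_xyz(box)
# corners has shape [1, 8, 3]; every corner lies at (+/-1, +/-1, +/-1).
```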
"""
# Relative corners w.r.t. the origin point.
# This returns all 8 corners in top-down counter-clockwise order, rather
# than only the top-left and bottom-right corners.
rel_corners = tf.constant(
[
[0.5, 0.5, 0.5], # top
[-0.5, 0.5, 0.5], # top
[-0.5, -0.5, 0.5], # top
[0.5, -0.5, 0.5], # top
[0.5, 0.5, -0.5], # bottom
[-0.5, 0.5, -0.5], # bottom
[-0.5, -0.5, -0.5], # bottom
[0.5, -0.5, -0.5], # bottom
]
)
centers = boxes[..., :3]
dimensions = boxes[..., 3:6]
phi_world = boxes[..., 6]
leading_shapes = boxes.shape.as_list()[:-1]
cos = tf.cos(phi_world)
sin = tf.sin(phi_world)
zero = tf.zeros_like(cos)
one = tf.ones_like(cos)
rotations = tf.reshape(
tf.stack([cos, -sin, zero, sin, cos, zero, zero, zero, one], axis=-1),
leading_shapes + [3, 3],
)
# apply the delta to convert from centers to relative corners format
rel_corners = tf.einsum("...ni,ji->...nji", dimensions, rel_corners)
# apply rotation matrix on relative corners
rel_corners = tf.einsum("...nij,...nkj->...nki", rotations, rel_corners)
# translate back to absolute corners format
corners = rel_corners + tf.reshape(centers, leading_shapes + [1, 3])
return corners
def _is_on_lefthand_side(points, v1, v2):
"""Checks if points lay on a vector direction or to its left.
Args:
point: float Tensor of [num_points, 2] of points to check
v1: float Tensor of [num_points, 2] of starting point of the vector
v2: float Tensor of [num_points, 2] of ending point of the vector
Returns:
a boolean Tensor of [num_points] indicate whether each point is on
the left of the vector or on the vector direction.
"""
# Prepare for broadcast: All point operations are on the right,
# and all v1/v2 operations are on the left. This is faster than left/right
# under the assumption that we have more points than vertices.
points_x = points[..., tf.newaxis, :, 0]
points_y = points[..., tf.newaxis, :, 1]
v1_x = v1[..., 0, tf.newaxis]
v2_x = v2[..., 0, tf.newaxis]
v1_y = v1[..., 1, tf.newaxis]
v2_y = v2[..., 1, tf.newaxis]
d1 = (points_y - v1_y) * (v2_x - v1_x)
d2 = (points_x - v1_x) * (v2_y - v1_y)
return d1 >= d2
def _box_area(boxes):
"""Compute the area of 2-d boxes.
Vertices must be ordered counter-clockwise. This function can
technically handle any kind of convex polygons.
Args:
boxes: a float Tensor of [..., 4, 2] of boxes. The last coordinates
are the four corners of the box and (x, y). The corners must be given in
counter-clockwise order.
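Usage (a minimal illustrative sketch of the shoelace formula this
implements):
```python
square = tf.constant([[[0.0, 0.0], [2.0, 0.0], [2.0, 2.0], [0.0, 2.0]]])
_box_area(square)  # [[4.0]]
```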
"""
boxes_roll = tf.roll(boxes, shift=1, axis=-2)
det = (
tf.reduce_sum(
boxes[..., 0] * boxes_roll[..., 1] - boxes[..., 1] * boxes_roll[..., 0],
axis=-1,
keepdims=True,
)
/ 2.0
)
return tf.abs(det)
def is_within_box2d(points, boxes):
"""Checks if 3d points are within 2d bounding boxes.
Currently only xy format is supported.
This function returns true if points are strictly inside the box or on edge.
Args:
points: [num_points, 2] float32 Tensor for 2d points in xy format.
boxes: [num_boxes, 4, 2] float32 Tensor for 2d boxes in xy format,
counter clockwise.
Returns:
boolean Tensor of shape [num_points, num_boxes]
"""
v1, v2, v3, v4 = (
boxes[..., 0, :],
boxes[..., 1, :],
boxes[..., 2, :],
boxes[..., 3, :],
)
is_inside = tf.math.logical_and(
tf.math.logical_and(
_is_on_lefthand_side(points, v1, v2), _is_on_lefthand_side(points, v2, v3)
),
tf.math.logical_and(
_is_on_lefthand_side(points, v3, v4), _is_on_lefthand_side(points, v4, v1)
),
)
valid_area = tf.greater(_box_area(boxes), 0)
is_inside = tf.math.logical_and(is_inside, valid_area)
# swap the last two dimensions
is_inside = tf.einsum("...ij->...ji", tf.cast(is_inside, tf.int32))
return tf.cast(is_inside, tf.bool)
def is_within_box3d(points, boxes):
"""Checks if 3d points are within 3d bounding boxes.
Currently only xyz format is supported.
Args:
points: [..., num_points, 3] float32 Tensor for 3d points in xyz format.
boxes: [..., num_boxes, 7] float32 Tensor for 3d boxes in [x, y, z, dx,
dy, dz, phi].
Returns:
boolean Tensor of shape [..., num_points, num_boxes] indicating whether
the point belongs to the box.
"""
# step 1 -- determine if points are within xy range
# convert from center format to corner format
boxes_corner = _center_xyzWHD_to_corner_xyz(boxes)
# project to 2d boxes by only taking x, y on top plane
boxes_2d = boxes_corner[..., 0:4, 0:2]
# project to 2d points by only taking x, y
points_2d = points[..., :2]
# check whether points are within 2d boxes, [..., num_points, num_boxes]
is_inside_2d = is_within_box2d(points_2d, boxes_2d)
# step 2 -- determine if points are within z range
[_, _, z, _, _, dz, _] = tf.split(boxes, 7, axis=-1)
z = z[..., 0]
dz = dz[..., 0]
bottom = z - dz / 2.0
# [..., 1, num_boxes]
bottom = bottom[..., tf.newaxis, :]
top = z + dz / 2.0
top = top[..., tf.newaxis, :]
# [..., num_points, 1]
points_z = points[..., 2:]
# [..., num_points, num_boxes]
is_inside_z = tf.math.logical_and(
tf.less_equal(points_z, top), tf.greater_equal(points_z, bottom)
)
return tf.math.logical_and(is_inside_z, is_inside_2d)
def coordinate_transform(points, pose):
"""
Translates and rotates 'points' into the coordinate frame defined by the 'pose' vector.
pose should contain 6 floating point values:
translate_x, translate_y, translate_z: The translation to apply.
yaw, roll, pitch: The rotation angles in radians.
Args:
points: Float shape [..., 3]: Points to transform to new coordinates.
pose: Float shape [6]: [translate_x, translate_y, translate_z, yaw, roll,
pitch]. The pose in the frame that 'points' comes from, and the definition
of the rotation and translation angles to apply to points.
Returns:
'points' transformed to the coordinates defined by 'pose'.
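For example (illustrative only), the identity pose leaves points unchanged:
```python
points = tf.constant([[1.0, 2.0, 3.0]])
identity_pose = tf.constant([0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
coordinate_transform(points, identity_pose)  # [[1., 2., 3.]]
```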
"""
translate_x = pose[0]
translate_y = pose[1]
translate_z = pose[2]
# Translate the points so the origin is the pose's center.
translation = tf.reshape([translate_x, translate_y, translate_z], shape=[3])
translated_points = points + translation
# Compose the rotations along the three axes.
#
# Note: Yaw->Z, Roll->X, Pitch->Y.
yaw, roll, pitch = pose[3], pose[4], pose[5]
rotation_matrix = _get_3d_rotation_matrix(yaw, roll, pitch)
# Finally, rotate the points about the pose's origin according to the
# rotation matrix.
rotated_points = tf.einsum("...i,...ij->...j", translated_points, rotation_matrix)
return rotated_points
def spherical_coordinate_transform(points):
"""Converts points from xyz coordinates to spherical coordinates.
See
https://en.wikipedia.org/wiki/Spherical_coordinate_system#Coordinate_system_conversions
for definitions of the transformations.
Args:
points: A floating point tensor with shape [..., 3], where the inner 3
dimensions correspond to xyz coordinates.
Returns:
A floating point tensor with the same shape [..., 3], where the inner
dimensions correspond to (dist, theta, phi), where phi corresponds to
azimuth/yaw (rotation around z), and theta corresponds to pitch/inclination
(rotation around y).
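For example (illustrative only), a point on the positive x-axis maps to
dist=1, theta=pi/2, phi=0:
```python
spherical_coordinate_transform(tf.constant([[1.0, 0.0, 0.0]]))
# ~[[1.0, 1.5708, 0.0]]
```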
"""
dist = tf.sqrt(tf.reduce_sum(tf.square(points), axis=-1))
theta = tf.acos(points[..., 2] / tf.maximum(dist, 1e-7))
# Note: tf.atan2 takes in (y, x).
phi = tf.atan2(points[..., 1], points[..., 0])
return tf.stack([dist, theta, phi], axis=-1)
# Copyright 2022 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import tensorflow as tf
from absl.testing import parameterized
from keras_cv import ops
class Boxes3DTestCase(tf.test.TestCase, parameterized.TestCase):
def test_convert_center_to_corners(self):
boxes = tf.constant(
[
[[1, 2, 3, 4, 3, 6, 0], [1, 2, 3, 4, 3, 6, 0]],
[[1, 2, 3, 4, 3, 6, np.pi / 2.0], [1, 2, 3, 4, 3, 6, np.pi / 2.0]],
]
)
corners = ops._center_xyzWHD_to_corner_xyz(boxes)
self.assertEqual((2, 2, 8, 3), corners.shape)
for i in [0, 1]:
self.assertAllClose(-1, np.min(corners[0, i, :, 0]))
self.assertAllClose(3, np.max(corners[0, i, :, 0]))
self.assertAllClose(0.5, np.min(corners[0, i, :, 1]))
self.assertAllClose(3.5, np.max(corners[0, i, :, 1]))
self.assertAllClose(0, np.min(corners[0, i, :, 2]))
self.assertAllClose(6, np.max(corners[0, i, :, 2]))
for i in [0, 1]:
self.assertAllClose(-0.5, np.min(corners[1, i, :, 0]))
self.assertAllClose(2.5, np.max(corners[1, i, :, 0]))
self.assertAllClose(0.0, np.min(corners[1, i, :, 1]))
self.assertAllClose(4.0, np.max(corners[1, i, :, 1]))
self.assertAllClose(0, np.min(corners[1, i, :, 2]))
self.assertAllClose(6, np.max(corners[1, i, :, 2]))
def test_within_box2d(self):
boxes = tf.constant(
[[[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]]], dtype=tf.float32
)
points = tf.constant(
[
[-0.5, -0.5],
[0.5, -0.5],
[1.5, -0.5],
[1.5, 0.5],
[1.5, 1.5],
[0.5, 1.5],
[-0.5, 1.5],
[-0.5, 0.5],
[1.0, 1.0],
[0.5, 0.5],
],
dtype=tf.float32,
)
is_inside = ops.is_within_box2d(points, boxes)
expected = [[False]] * 8 + [[True]] * 2
self.assertAllEqual(expected, is_inside)
def test_within_zero_box2d(self):
bbox = tf.constant(
[[[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]], dtype=tf.float32
)
points = tf.constant(
[
[-0.5, -0.5],
[0.5, -0.5],
[1.5, -0.5],
[1.5, 0.5],
[1.5, 1.5],
[0.5, 1.5],
[-0.5, 1.5],
[-0.5, 0.5],
[1.0, 1.0],
[0.5, 0.5],
],
dtype=tf.float32,
)
is_inside = ops.is_within_box2d(points, bbox)
expected = [[False]] * 10
self.assertAllEqual(expected, is_inside)
def test_is_on_lefthand_side(self):
v1 = tf.constant([[0.0, 0.0]], dtype=tf.float32)
v2 = tf.constant([[1.0, 0.0]], dtype=tf.float32)
p = tf.constant([[0.5, 0.5], [-1.0, -3], [-1.0, 1.0]], dtype=tf.float32)
res = ops._is_on_lefthand_side(p, v1, v2)
self.assertAllEqual([[True, False, True]], res)
res = ops._is_on_lefthand_side(v1, v1, v2)
self.assertAllEqual([[True]], res)
res = ops._is_on_lefthand_side(v2, v1, v2)
self.assertAllEqual([[True]], res)
@parameterized.named_parameters(
("without_rotation", 0.0),
("with_rotation_1_rad", 1.0),
("with_rotation_2_rad", 2.0),
("with_rotation_3_rad", 3.0),
)
def test_box_area(self, angle):
boxes = tf.constant(
[
[[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]],
[[0.0, 0.0], [2.0, 0.0], [2.0, 1.0], [0.0, 1.0]],
[[0.0, 0.0], [2.0, 0.0], [2.0, 2.0], [0.0, 2.0]],
],
dtype=tf.float32,
)
expected = [[1.0], [2.0], [4.0]]
def _rotate(bbox, theta):
rotation_matrix = tf.reshape(
[tf.cos(theta), -tf.sin(theta), tf.sin(theta), tf.cos(theta)],
shape=(2, 2),
)
return tf.matmul(bbox, rotation_matrix)
rotated_bboxes = _rotate(boxes, angle)
res = ops._box_area(rotated_bboxes)
self.assertAllClose(expected, res)
def test_within_box3d(self):
num_points, num_boxes = 19, 4
# rotate the first box by pi / 2 so dim_x and dim_y are swapped.
# The last box is a cube rotated by 45 degrees.
bboxes = tf.constant(
[
[1.0, 2.0, 3.0, 6.0, 0.4, 6.0, np.pi / 2],
[4.0, 5.0, 6.0, 7.0, 0.8, 7.0, 0.0],
[0.4, 0.3, 0.2, 0.1, 0.1, 0.2, 0.0],
[-10.0, -10.0, -10.0, 3.0, 3.0, 3.0, np.pi / 4],
],
dtype=tf.float32,
)
points = tf.constant(
[
[1.0, 2.0, 3.0], # box 0 (centroid)
[0.8, 2.0, 3.0], # box 0 (below x)
[1.1, 2.0, 3.0], # box 0 (above x)
[1.3, 2.0, 3.0], # box 0 (too far x)
[0.7, 2.0, 3.0], # box 0 (too far x)
[4.0, 5.0, 6.0], # box 1 (centroid)
[4.0, 4.6, 6.0], # box 1 (below y)
[4.0, 5.4, 6.0], # box 1 (above y)
[4.0, 4.5, 6.0], # box 1 (too far y)
[4.0, 5.5, 6.0], # box 1 (too far y)
[0.4, 0.3, 0.2], # box 2 (centroid)
[0.4, 0.3, 0.1], # box 2 (below z)
[0.4, 0.3, 0.3], # box 2 (above z)
[0.4, 0.3, 0.0], # box 2 (too far z)
[0.4, 0.3, 0.4], # box 2 (too far z)
[5.0, 7.0, 8.0], # none
[1.0, 5.0, 3.6], # box0, box1
[-11.6, -10.0, -10.0], # box3 (rotated corner point).
[-11.4, -11.4, -10.0], # not in box3, would be if not rotated.
],
dtype=tf.float32,
)
expected_is_inside = np.array(
[
[True, False, False, False],
[True, False, False, False],
[True, False, False, False],
[False, False, False, False],
[False, False, False, False],
[False, True, False, False],
[False, True, False, False],
[False, True, False, False],
[False, False, False, False],
[False, False, False, False],
[False, False, True, False],
[False, False, True, False],
[False, False, True, False],
[False, False, False, False],
[False, False, False, False],
[False, False, False, False],
[True, True, False, False],
[False, False, False, True],
[False, False, False, False],
]
)
assert points.shape[0] == num_points
assert bboxes.shape[0] == num_boxes
assert expected_is_inside.shape[0] == num_points
assert expected_is_inside.shape[1] == num_boxes
is_inside = ops.is_within_box3d(points, bboxes)
self.assertAllEqual([num_points, num_boxes], is_inside.shape)
self.assertAllEqual(expected_is_inside, is_inside)
# Add a batch dimension to the data and see that it still works
# as expected.
batch_size = 3
points = tf.tile(points[tf.newaxis, ...], [batch_size, 1, 1])
bboxes = tf.tile(bboxes[tf.newaxis, ...], [batch_size, 1, 1])
is_inside = ops.is_within_box3d(points, bboxes)
self.assertAllEqual([batch_size, num_points, num_boxes], is_inside.shape)
for batch_idx in range(batch_size):
self.assertAllEqual(expected_is_inside, is_inside[batch_idx])
def testCoordinateTransform(self):
# This is a validated test case from a real scene.
#
# A single 3D point in world coordinates.
point = tf.constant(
[[[5736.94580078, 1264.85168457, 45.0271225]]], dtype=tf.float32
)
# Replicate the point to test broadcasting behavior.
replicated_points = tf.tile(point, [2, 4, 1])
# Pose of the car (x, y, z, yaw, roll, pitch).
#
# We negate the translations so that the coordinates are translated
# such that the car is at the origin.
pose = tf.constant(
[
-5728.77148438,
-1264.42236328,
-45.06399918,
-3.10496902,
0.03288471,
0.00115049,
],
dtype=tf.float32,
)
result = ops.coordinate_transform(replicated_points, pose)
# We expect the point to be translated close to the car, and then rotated
# mostly around the x-axis.
# The result is device dependent; skip or ignore this test locally if it fails.
expected = np.tile([[[-8.184512, -0.13086952, -0.04200769]]], [2, 4, 1])
self.assertAllClose(expected, result)
def testSphericalCoordinatesTransform(self):
np_xyz = np.random.randn(5, 6, 3)
points = tf.constant(np_xyz, dtype=tf.float32)
spherical_coordinates = ops.spherical_coordinate_transform(points)
# Convert coordinates back to xyz to verify.
dist = spherical_coordinates[..., 0]
theta = spherical_coordinates[..., 1]
phi = spherical_coordinates[..., 2]
x = dist * np.sin(theta) * np.cos(phi)
y = dist * np.sin(theta) * np.sin(phi)
z = dist * np.cos(theta)
self.assertAllClose(x, np_xyz[..., 0])
self.assertAllClose(y, np_xyz[..., 1])
self.assertAllClose(z, np_xyz[..., 2])
# Copyright 2022 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
def balanced_sample(
positive_matches: tf.Tensor,
negative_matches: tf.Tensor,
num_samples: int,
positive_fraction: float,
):
"""
Sampling op to balance positive and negative samples. Handles both
batched and unbatched inputs.
Args:
positive_matches: [N] or [batch_size, N] boolean Tensor, True for
indicating the index is a positive sample
negative_matches: [N] or [batch_size, N] boolean Tensor, True for
indicating the index is a negative sample
num_samples: int, representing the number of samples to collect
positive_fraction: float. 0.5 means positive samples should be half
of all collected samples.
Returns:
selected_indicators: [N] or [batch_size, N]
indicator Tensor, where 1 means the index is sampled and 0 means it
is not sampled.
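Usage (a minimal illustrative sketch):
```python
positive_matches = tf.constant([True, False, False, False])
negative_matches = tf.constant([False, True, True, True])
indicators = balanced_sample(
    positive_matches, negative_matches, num_samples=2, positive_fraction=0.5
)
# `indicators` has ones at the 2 sampled indices: the positive at index 0
# plus 1 randomly chosen negative.
```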
"""
N = positive_matches.get_shape().as_list()[-1]
if N < num_samples:
raise ValueError(
f"passed in {positive_matches.shape} has less element than {num_samples}"
)
# Assign each index a randomized score: positives near 1.0, negatives near
# 0.5, and invalid (ignored) matches exactly 0, then sample via top_k.
zeros = tf.zeros_like(positive_matches, dtype=tf.float32)
ones = tf.ones_like(positive_matches, dtype=tf.float32)
ones_rand = ones + tf.random.uniform(ones.shape, minval=-0.2, maxval=0.2)
halfs = 0.5 * tf.ones_like(positive_matches, dtype=tf.float32)
halfs_rand = halfs + tf.random.uniform(halfs.shape, minval=-0.2, maxval=0.2)
values = zeros
values = tf.where(positive_matches, ones_rand, values)
values = tf.where(negative_matches, halfs_rand, values)
num_pos_samples = int(num_samples * positive_fraction)
valid_matches = tf.logical_or(positive_matches, negative_matches)
# this might contain negative samples as well
_, positive_indices = tf.math.top_k(values, k=num_pos_samples)
selected_indicators = tf.cast(
tf.reduce_sum(tf.one_hot(positive_indices, depth=N), axis=-2), tf.bool
)
# setting all selected samples to zeros
values = tf.where(selected_indicators, zeros, values)
# setting all excessive positive matches to zeros as well
values = tf.where(positive_matches, zeros, values)
num_neg_samples = num_samples - num_pos_samples
_, negative_indices = tf.math.top_k(values, k=num_neg_samples)
selected_indices = tf.concat([positive_indices, negative_indices], axis=-1)
selected_indicators = tf.reduce_sum(tf.one_hot(selected_indices, depth=N), axis=-2)
selected_indicators = tf.minimum(
selected_indicators, tf.ones_like(selected_indicators)
)
selected_indicators = tf.where(
valid_matches, selected_indicators, tf.zeros_like(selected_indicators)
)
return selected_indicators
# Copyright 2022 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
from keras_cv.ops.sampling import balanced_sample
class BalancedSamplingTest(tf.test.TestCase):
def test_balanced_sampling(self):
positive_matches = tf.constant(
[True, False, False, False, False, False, False, False, False, False]
)
negative_matches = tf.constant(
[False, True, True, True, True, True, True, True, True, True]
)
num_samples = 5
positive_fraction = 0.2
res = balanced_sample(
positive_matches, negative_matches, num_samples, positive_fraction
)
# The 1st element must be selected, given it's the only one.
self.assertAllClose(res[0], 1)
def test_balanced_batched_sampling(self):
positive_matches = tf.constant(
[
[True, False, False, False, False, False, False, False, False, False],
[False, False, False, False, False, False, True, False, False, False],
]
)
negative_matches = tf.constant(
[
[False, True, True, True, True, True, True, True, True, True],
[True, True, True, True, True, True, False, True, True, True],
]
)
num_samples = 5
positive_fraction = 0.2
res = balanced_sample(
positive_matches, negative_matches, num_samples, positive_fraction
)
# the 1st element from the 1st batch must be selected, given it's the only one
self.assertAllClose(res[0][0], 1)
# the 7th element from the 2nd batch must be selected, given it's the only one
self.assertAllClose(res[1][6], 1)
def test_balanced_sampling_over_positive_fraction(self):
positive_matches = tf.constant(
[True, False, False, False, False, False, False, False, False, False]
)
negative_matches = tf.constant(
[False, True, True, True, True, True, True, True, True, True]
)
num_samples = 5
positive_fraction = 0.4
res = balanced_sample(
positive_matches, negative_matches, num_samples, positive_fraction
)
# only 1 positive sample exists, thus it is chosen
self.assertAllClose(res[0], 1)
def test_balanced_sampling_under_positive_fraction(self):
positive_matches = tf.constant(
[True, False, False, False, False, False, False, False, False, False]
)
negative_matches = tf.constant(
[False, True, True, True, True, True, True, True, True, True]
)
num_samples = 5
positive_fraction = 0.1
res = balanced_sample(
positive_matches, negative_matches, num_samples, positive_fraction
)
# no positive is chosen
self.assertAllClose(res[0], 0)
self.assertAllClose(tf.reduce_sum(res), 5)
def test_balanced_sampling_over_num_samples(self):
positive_matches = tf.constant(
[True, False, False, False, False, False, False, False, False, False]
)
negative_matches = tf.constant(
[False, True, True, True, True, True, True, True, True, True]
)
# users want to get 20 samples, but only 10 are available
num_samples = 20
positive_fraction = 0.1
with self.assertRaisesRegex(ValueError, "has less element"):
_ = balanced_sample(
positive_matches, negative_matches, num_samples, positive_fraction
)
def test_balanced_sampling_no_positive(self):
positive_matches = tf.constant(
[False, False, False, False, False, False, False, False, False, False]
)
# The rest are neither positive nor negative, but ignored matches.
negative_matches = tf.constant(
[False, False, True, False, False, True, False, False, True, False]
)
num_samples = 5
positive_fraction = 0.5
res = balanced_sample(
positive_matches, negative_matches, num_samples, positive_fraction
)
# given only 3 negative and 0 positive, select all of them
self.assertAllClose(res, [0, 0, 1, 0, 0, 1, 0, 0, 1, 0])
def test_balanced_sampling_no_negative(self):
positive_matches = tf.constant(
[True, True, False, False, False, False, False, False, False, False]
)
# Indices 2-9 are neither positive nor negative; they're ignored matches.
negative_matches = tf.constant([False] * 10)
num_samples = 5
positive_fraction = 0.5
res = balanced_sample(
positive_matches, negative_matches, num_samples, positive_fraction
)
# given only 2 positive and 0 negative, select all of them.
self.assertAllClose(res, [1, 1, 0, 0, 0, 0, 0, 0, 0, 0])
def test_balanced_sampling_many_samples(self):
positive_matches = tf.random.uniform(
[2, 1000], minval=0, maxval=1, dtype=tf.float32
)
positive_matches = positive_matches > 0.98
negative_matches = tf.logical_not(positive_matches)
num_samples = 256
positive_fraction = 0.25
_ = balanced_sample(
positive_matches, negative_matches, num_samples, positive_fraction
)
# Copyright 2022 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import tensorflow as tf
def _target_gather(
targets: tf.Tensor,
indices: tf.Tensor,
mask: Optional[tf.Tensor] = None,
mask_val: Optional[float] = 0.0,
):
"""A utility function wrapping tf.gather, which deals with:
1) both batched and unbatched `targets`
2) when unbatched `targets` have empty rows, the result will be filled
with `mask_val`
3) target masking.
Args:
targets: [N, ...] or [batch_size, N, ...] Tensor representing targets such
as boxes, keypoints, etc.
indices: [M] or [batch_size, M] int32 Tensor representing indices within
`targets` to gather.
mask: optional [M, ...] or [batch_size, M, ...] boolean
Tensor representing the masking for each target. `True` means the
corresponding entity should be masked to `mask_val`, `False` means the
corresponding entity should keep the target value.
mask_val: optional float representing the masking value if `mask` is
True for the entity.
Returns:
targets: [M, ...] or [batch_size, M, ...] Tensor representing selected targets.
Raises:
ValueError: if `targets` has rank larger than 3.
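Usage (a minimal illustrative sketch):
```python
boxes = tf.constant([[0, 0, 5, 5], [0, 5, 5, 10], [5, 0, 10, 5]])
indices = tf.constant([0, 2], dtype=tf.int32)
selected = _target_gather(boxes, indices)
# selected == [[0, 0, 5, 5], [5, 0, 10, 5]]
```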
"""
targets_shape = targets.get_shape().as_list()
if len(targets_shape) > 3:
raise ValueError(
"`target_gather` does not support `targets` with rank "
"larger than 3, got {}".format(len(targets.shape))
)
def _gather_unbatched(labels, match_indices, mask, mask_val):
"""Gather based on unbatched labels and boxes."""
num_gt_boxes = tf.shape(labels)[0]
def _assign_when_rows_empty():
if len(labels.shape) > 1:
mask_shape = [match_indices.shape[0], labels.shape[-1]]
else:
mask_shape = [match_indices.shape[0]]
return tf.cast(mask_val, labels.dtype) * tf.ones(
mask_shape, dtype=labels.dtype
)
def _assign_when_rows_not_empty():
targets = tf.gather(labels, match_indices)
if mask is None:
return targets
else:
masked_targets = tf.cast(mask_val, labels.dtype) * tf.ones_like(
mask, dtype=labels.dtype
)
return tf.where(mask, masked_targets, targets)
return tf.cond(
tf.greater(num_gt_boxes, 0),
_assign_when_rows_not_empty,
_assign_when_rows_empty,
)
def _gather_batched(labels, match_indices, mask, mask_val):
"""Gather based on batched labels."""
batch_size = labels.shape[0]
if batch_size == 1:
if mask is not None:
result = _gather_unbatched(
tf.squeeze(labels, axis=0),
tf.squeeze(match_indices, axis=0),
tf.squeeze(mask, axis=0),
mask_val,
)
else:
result = _gather_unbatched(
tf.squeeze(labels, axis=0),
tf.squeeze(match_indices, axis=0),
None,
mask_val,
)
return tf.expand_dims(result, axis=0)
else:
indices_shape = tf.shape(match_indices)
indices_dtype = match_indices.dtype
batch_indices = tf.expand_dims(
tf.range(indices_shape[0], dtype=indices_dtype), axis=-1
) * tf.ones([1, indices_shape[-1]], dtype=indices_dtype)
gather_nd_indices = tf.stack([batch_indices, match_indices], axis=-1)
targets = tf.gather_nd(labels, gather_nd_indices)
if mask is None:
return targets
else:
masked_targets = tf.cast(mask_val, labels.dtype) * tf.ones_like(
mask, dtype=labels.dtype
)
return tf.where(mask, masked_targets, targets)
if len(targets_shape) <= 2:
return _gather_unbatched(targets, indices, mask, mask_val)
elif len(targets_shape) == 3:
return _gather_batched(targets, indices, mask, mask_val)
# Copyright 2022 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
from keras_cv.ops.target_gather import _target_gather
class TargetGatherTest(tf.test.TestCase):
def test_target_gather_boxes_batched(self):
target_boxes = tf.constant(
[[0, 0, 5, 5], [0, 5, 5, 10], [5, 0, 10, 5], [5, 5, 10, 10]]
)
target_boxes = target_boxes[tf.newaxis, ...]
indices = tf.constant([[0, 2]], dtype=tf.int32)
expected_boxes = tf.constant([[0, 0, 5, 5], [5, 0, 10, 5]])
expected_boxes = expected_boxes[tf.newaxis, ...]
res = _target_gather(target_boxes, indices)
self.assertAllClose(expected_boxes, res)
def test_target_gather_boxes_unbatched(self):
target_boxes = tf.constant(
[[0, 0, 5, 5], [0, 5, 5, 10], [5, 0, 10, 5], [5, 5, 10, 10]]
)
indices = tf.constant([0, 2], dtype=tf.int32)
expected_boxes = tf.constant([[0, 0, 5, 5], [5, 0, 10, 5]])
res = _target_gather(target_boxes, indices)
self.assertAllClose(expected_boxes, res)
def test_target_gather_classes_batched(self):
target_classes = tf.constant([[1, 2, 3, 4]])
target_classes = target_classes[..., tf.newaxis]
indices = tf.constant([[0, 2]], dtype=tf.int32)
expected_classes = tf.constant([[1, 3]])
expected_classes = expected_classes[..., tf.newaxis]
res = _target_gather(target_classes, indices)
self.assertAllClose(expected_classes, res)
def test_target_gather_classes_unbatched(self):
target_classes = tf.constant([1, 2, 3, 4])
target_classes = target_classes[..., tf.newaxis]
indices = tf.constant([0, 2], dtype=tf.int32)
expected_classes = tf.constant([1, 3])
expected_classes = expected_classes[..., tf.newaxis]
res = _target_gather(target_classes, indices)
self.assertAllClose(expected_classes, res)
def test_target_gather_classes_batched_with_mask(self):
target_classes = tf.constant([[1, 2, 3, 4]])
target_classes = target_classes[..., tf.newaxis]
indices = tf.constant([[0, 2]], dtype=tf.int32)
masks = tf.constant(([[False, True]]))
masks = masks[..., tf.newaxis]
# the second element is masked
expected_classes = tf.constant([[1, 0]])
expected_classes = expected_classes[..., tf.newaxis]
res = _target_gather(target_classes, indices, masks)
self.assertAllClose(expected_classes, res)
def test_target_gather_classes_batched_with_mask_val(self):
target_classes = tf.constant([[1, 2, 3, 4]])
target_classes = target_classes[..., tf.newaxis]
indices = tf.constant([[0, 2]], dtype=tf.int32)
masks = tf.constant(([[False, True]]))
masks = masks[..., tf.newaxis]
# the second element is masked
expected_classes = tf.constant([[1, -1]])
expected_classes = expected_classes[..., tf.newaxis]
res = _target_gather(target_classes, indices, masks, -1)
self.assertAllClose(expected_classes, res)
def test_target_gather_classes_unbatched_with_mask(self):
target_classes = tf.constant([1, 2, 3, 4])
target_classes = target_classes[..., tf.newaxis]
indices = tf.constant([0, 2], dtype=tf.int32)
masks = tf.constant([False, True])
masks = masks[..., tf.newaxis]
expected_classes = tf.constant([1, 0])
expected_classes = expected_classes[..., tf.newaxis]
res = _target_gather(target_classes, indices, masks)
self.assertAllClose(expected_classes, res)
def test_target_gather_with_empty_targets(self):
target_classes = tf.constant([])
target_classes = target_classes[..., tf.newaxis]
indices = tf.constant([0, 2], dtype=tf.int32)
# return all 0s since input is empty
expected_classes = tf.constant([0, 0])
expected_classes = expected_classes[..., tf.newaxis]
res = _target_gather(target_classes, indices)
self.assertAllClose(expected_classes, res)
def test_target_gather_classes_multi_batch(self):
target_classes = tf.constant([[1, 2, 3, 4], [5, 6, 7, 8]])
target_classes = target_classes[..., tf.newaxis]
indices = tf.constant([[0, 2], [1, 3]], dtype=tf.int32)
expected_classes = tf.constant([[1, 3], [6, 8]])
expected_classes = expected_classes[..., tf.newaxis]
res = _target_gather(target_classes, indices)
self.assertAllClose(expected_classes, res)
def test_target_gather_invalid_rank(self):
targets = tf.random.normal([32, 2, 2, 2])
indices = tf.constant([0, 1], dtype=tf.int32)
with self.assertRaisesRegex(ValueError, "larger than 3"):
_ = _target_gather(targets, indices)
# Copyright 2022 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from keras_cv.training.contrastive.contrastive_trainer import ContrastiveTrainer
from keras_cv.training.contrastive.simclr_trainer import SimCLRAugmenter
from keras_cv.training.contrastive.simclr_trainer import SimCLRTrainer
# Copyright 2022 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2022 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
from tensorflow import keras
from keras_cv.utils.train import convert_inputs_to_tf_dataset
class ContrastiveTrainer(keras.Model):
"""Creates a self-supervised contrastive trainer for a model.
Args:
encoder: a `keras.Model` to be pre-trained. In most cases, this encoder
should not include a top dense layer.
augmenter: a preprocessing layer to randomly augment input images for contrastive learning,
or a tuple of two separate augmenters for the two sides of the contrastive pipeline.
projector: a projection model for contrastive training, or a tuple of two separate
projectors for the two sides of the contrastive pipeline. This shrinks
the feature map produced by the encoder, and is usually a 1 or
2-layer dense MLP.
probe: An optional Keras layer or model which will be trained against
class labels at train-time using the encoder output as input.
Note that this should be specified iff training with labeled images.
This predicts class labels based on the feature map produced by the
encoder and is usually a 1 or 2-layer dense MLP.
Returns:
A `keras.Model` instance.
Usage:
```python
encoder = keras_cv.models.DenseNet121(include_rescaling=True, include_top=False, pooling="avg")
augmenter = keras_cv.layers.preprocessing.RandomFlip()
projector = keras.layers.Dense(64)
probe = keras_cv.training.ContrastiveTrainer.linear_probe(classes=10)
trainer = keras_cv.training.ContrastiveTrainer(
encoder=encoder,
augmenter=augmenter,
projector=projector,
probe=probe
)
trainer.compile(
encoder_optimizer=keras.optimizers.Adam(),
encoder_loss=keras_cv.losses.SimCLRLoss(temperature=0.5),
probe_optimizer=keras.optimizers.Adam(),
probe_loss=keras.losses.CategoricalCrossentropy(from_logits=True),
probe_metrics=[keras.metrics.CategoricalAccuracy(name="probe_accuracy")]
)
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
y_train = keras.utils.to_categorical(y_train, 10)
trainer.fit(x_train, y_train)
```
"""
def __init__(
self,
encoder,
augmenter,
projector,
probe=None,
):
super().__init__()
if encoder.output.shape.rank != 2:
raise ValueError(
f"`encoder` must have a flattened output. Expected rank(encoder.output.shape)=2, got encoder.output.shape={encoder.output.shape}"
)
if type(augmenter) is tuple and len(augmenter) != 2:
raise ValueError(
"`augmenter` must be either a single augmenter or a tuple of exactly 2 augmenters."
)
if type(projector) is tuple and len(projector) != 2:
raise ValueError(
"`projector` must be either a single augmenter or a tuple of exactly 2 augmenters."
)
self.augmenters = (
augmenter if type(augmenter) is tuple else (augmenter, augmenter)
)
self.encoder = encoder
self.projectors = (
projector if type(projector) is tuple else (projector, projector)
)
self.probe = probe
self.loss_metric = keras.metrics.Mean(name="loss")
if probe is not None:
self.probe_loss_metric = keras.metrics.Mean(name="probe_loss")
self.probe_metrics = []
def compile(
self,
encoder_loss,
encoder_optimizer,
encoder_metrics=None,
probe_optimizer=None,
probe_loss=None,
probe_metrics=None,
**kwargs,
):
# Check for ambiguous kwargs before forwarding them to super().compile;
# otherwise a duplicate keyword argument would raise a TypeError first.
if "loss" in kwargs:
raise ValueError(
"`loss` parameter in ContrastiveTrainer.compile is ambiguous. Please specify `encoder_loss` or `probe_loss`."
)
if "optimizer" in kwargs:
raise ValueError(
"`optimizer` parameter in ContrastiveTrainer.compile is ambiguous. Please specify `encoder_optimizer` or `probe_optimizer`."
)
if "metrics" in kwargs:
raise ValueError(
"`metrics` parameter in ContrastiveTrainer.compile is ambiguous. Please specify `encoder_metrics` or `probe_metrics`."
)
if self.probe and not probe_optimizer:
raise ValueError(
"`probe_optimizer` must be specified when a probe is included."
)
if self.probe and not probe_loss:
raise ValueError("`probe_loss` must be specified when a probe is included.")
super().compile(
loss=encoder_loss,
optimizer=encoder_optimizer,
metrics=encoder_metrics,
**kwargs,
)
if self.probe:
self.probe_loss = probe_loss
self.probe_optimizer = probe_optimizer
self.probe_metrics = probe_metrics or []
@property
def metrics(self):
metrics = [
self.loss_metric,
]
if self.probe:
metrics += [self.probe_loss_metric]
metrics += self.probe_metrics
return super().metrics + metrics
def fit(
self,
x=None,
y=None,
sample_weight=None,
batch_size=None,
**kwargs,
):
dataset = convert_inputs_to_tf_dataset(
x=x, y=y, sample_weight=sample_weight, batch_size=batch_size
)
dataset = dataset.map(self.run_augmenters, num_parallel_calls=tf.data.AUTOTUNE)
dataset = dataset.prefetch(tf.data.AUTOTUNE)
return super().fit(x=dataset, **kwargs)
def run_augmenters(self, x, y=None):
inputs = {"images": x}
if y is not None:
inputs["labels"] = y
inputs["augmented_images_0"] = self.augmenters[0](x, training=True)
inputs["augmented_images_1"] = self.augmenters[1](x, training=True)
return inputs
def train_step(self, data):
images = data["images"]
labels = data["labels"] if "labels" in data else None
augmented_images_0 = data["augmented_images_0"]
augmented_images_1 = data["augmented_images_1"]
with tf.GradientTape() as tape:
features_0 = self.encoder(augmented_images_0, training=True)
features_1 = self.encoder(augmented_images_1, training=True)
projections_0 = self.projectors[0](features_0, training=True)
projections_1 = self.projectors[1](features_1, training=True)
loss = self.compiled_loss(
projections_0, projections_1, regularization_losses=self.encoder.losses
)
gradients = tape.gradient(
loss,
self.encoder.trainable_weights
+ self.projectors[0].trainable_weights
+ self.projectors[1].trainable_weights,
)
self.optimizer.apply_gradients(
zip(
gradients,
self.encoder.trainable_weights
+ self.projectors[0].trainable_weights
+ self.projectors[1].trainable_weights,
)
)
self.loss_metric.update_state(loss)
if self.probe:
if labels is None:
raise ValueError("Targets must be provided when a probe is specified")
with tf.GradientTape() as tape:
features = tf.stop_gradient(self.encoder(images, training=False))
class_logits = self.probe(features, training=True)
probe_loss = self.probe_loss(labels, class_logits)
gradients = tape.gradient(probe_loss, self.probe.trainable_weights)
self.probe_optimizer.apply_gradients(
zip(gradients, self.probe.trainable_weights)
)
self.probe_loss_metric.update_state(probe_loss)
for metric in self.probe_metrics:
metric.update_state(labels, class_logits)
return {metric.name: metric.result() for metric in self.metrics}
def call(self, inputs):
raise NotImplementedError(
"ContrastiveTrainer.call() is not implemented - please call your model directly."
)
@staticmethod
def linear_probe(classes, **kwargs):
        return keras.Sequential([keras.layers.Dense(classes)], **kwargs)
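For reference, a minimal end-to-end sketch of driving `ContrastiveTrainer` directly (not part of the source: the dataset shapes, the 10-class probe, and the optimizer/loss choices are illustrative assumptions):

```python
import tensorflow as tf
from tensorflow import keras
from keras_cv.layers import preprocessing
from keras_cv.losses import SimCLRLoss
from keras_cv.models import DenseNet121
from keras_cv.training import ContrastiveTrainer

# The encoder must have a flat (rank-2) output, hence pooling="avg".
trainer = ContrastiveTrainer(
    encoder=DenseNet121(include_rescaling=True, include_top=False, pooling="avg"),
    augmenter=preprocessing.RandomFlip("horizontal"),  # used for both views
    projector=keras.layers.Dense(128),                 # shared by both views
    probe=ContrastiveTrainer.linear_probe(classes=10),
)
trainer.compile(
    encoder_optimizer=keras.optimizers.Adam(),
    encoder_loss=SimCLRLoss(temperature=0.5),
    probe_optimizer=keras.optimizers.Adam(),
    probe_loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)
# Labels are consumed only by the probe; they never reach the encoder loss.
images = tf.random.uniform((8, 64, 64, 3))
labels = tf.random.uniform((8,), maxval=10, dtype=tf.int32)
trainer.fit(images, labels, batch_size=8)
```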
# Copyright 2022 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import metrics
from tensorflow.keras import optimizers
from keras_cv.layers import preprocessing
from keras_cv.losses import SimCLRLoss
from keras_cv.models import DenseNet121
from keras_cv.training import ContrastiveTrainer
class ContrastiveTrainerTest(tf.test.TestCase):
def test_probe_requires_probe_optimizer(self):
trainer = ContrastiveTrainer(
encoder=self.build_encoder(),
augmenter=self.build_augmenter(),
projector=self.build_projector(),
probe=self.build_probe(),
)
with self.assertRaises(ValueError):
trainer.compile(
encoder_optimizer=optimizers.Adam(),
encoder_loss=SimCLRLoss(temperature=0.5),
)
def test_targets_required_if_probing(self):
trainer_with_probing = ContrastiveTrainer(
encoder=self.build_encoder(),
augmenter=self.build_augmenter(),
projector=self.build_projector(),
probe=self.build_probe(),
)
trainer_without_probing = ContrastiveTrainer(
encoder=self.build_encoder(),
augmenter=self.build_augmenter(),
projector=self.build_projector(),
probe=None,
)
images = tf.random.uniform((1, 50, 50, 3))
trainer_with_probing.compile(
encoder_optimizer=optimizers.Adam(),
encoder_loss=SimCLRLoss(temperature=0.5),
probe_optimizer=optimizers.Adam(),
probe_loss=keras.losses.CategoricalCrossentropy(from_logits=True),
)
trainer_without_probing.compile(
encoder_optimizer=optimizers.Adam(),
encoder_loss=SimCLRLoss(temperature=0.5),
)
with self.assertRaises(ValueError):
trainer_with_probing.fit(images)
def test_train_with_probing(self):
trainer_with_probing = ContrastiveTrainer(
encoder=self.build_encoder(),
augmenter=self.build_augmenter(),
projector=self.build_projector(),
probe=self.build_probe(classes=20),
)
images = tf.random.uniform((1, 50, 50, 3))
targets = tf.ones((1, 20))
trainer_with_probing.compile(
encoder_optimizer=optimizers.Adam(),
encoder_loss=SimCLRLoss(temperature=0.5),
probe_metrics=[metrics.TopKCategoricalAccuracy(3, "top3_probe_accuracy")],
probe_optimizer=optimizers.Adam(),
probe_loss=keras.losses.CategoricalCrossentropy(from_logits=True),
)
trainer_with_probing.fit(images, targets)
def test_train_without_probing(self):
trainer_without_probing = ContrastiveTrainer(
encoder=self.build_encoder(),
augmenter=self.build_augmenter(),
projector=self.build_projector(),
probe=None,
)
images = tf.random.uniform((1, 50, 50, 3))
targets = tf.ones((1, 20))
trainer_without_probing.compile(
encoder_optimizer=optimizers.Adam(),
encoder_loss=SimCLRLoss(temperature=0.5),
)
trainer_without_probing.fit(images)
trainer_without_probing.fit(images, targets)
def test_inference_not_supported(self):
trainer = ContrastiveTrainer(
encoder=self.build_encoder(),
augmenter=self.build_augmenter(),
projector=self.build_projector(),
probe=None,
)
trainer.compile(
encoder_optimizer=optimizers.Adam(),
encoder_loss=SimCLRLoss(temperature=0.5),
)
with self.assertRaises(NotImplementedError):
trainer(tf.ones((1, 50, 50, 3)))
def test_encoder_must_have_flat_output(self):
with self.assertRaises(ValueError):
_ = ContrastiveTrainer(
# A DenseNet without pooling does not have a flat output
encoder=DenseNet121(include_rescaling=False, include_top=False),
augmenter=self.build_augmenter(),
projector=self.build_projector(),
probe=None,
)
def test_with_multiple_augmenters_and_projectors(self):
augmenter0 = preprocessing.RandomFlip("horizontal")
augmenter1 = preprocessing.RandomFlip("vertical")
projector0 = layers.Dense(64, name="projector0")
projector1 = keras.Sequential(
[projector0, layers.ReLU(), layers.Dense(64, name="projector1")]
)
trainer_without_probing = ContrastiveTrainer(
encoder=self.build_encoder(),
augmenter=(augmenter0, augmenter1),
projector=(projector0, projector1),
probe=None,
)
images = tf.random.uniform((1, 50, 50, 3))
trainer_without_probing.compile(
encoder_optimizer=optimizers.Adam(),
encoder_loss=SimCLRLoss(temperature=0.5),
)
trainer_without_probing.fit(images)
def build_augmenter(self):
return preprocessing.RandomFlip("horizontal")
def build_encoder(self):
return DenseNet121(include_rescaling=False, include_top=False, pooling="avg")
def build_projector(self):
return layers.Dense(128)
def build_probe(self, classes=20):
return layers.Dense(classes)
# Copyright 2022 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from tensorflow import keras
from tensorflow.keras import layers
from keras_cv.layers import preprocessing
from keras_cv.training import ContrastiveTrainer
class SimCLRTrainer(ContrastiveTrainer):
"""Creates a SimCLRTrainer.
References:
- [SimCLR paper](https://arxiv.org/pdf/2002.05709)
Args:
encoder: a `keras.Model` to be pre-trained. In most cases, this encoder
should not include a top dense layer.
augmenter: a SimCLRAugmenter layer to randomly augment input
images for contrastive learning
projection_width: the width of the two-layer dense model used for
projection in the SimCLR paper
"""
def __init__(self, encoder, augmenter, projection_width=128, **kwargs):
super().__init__(
encoder=encoder,
augmenter=augmenter,
projector=keras.Sequential(
[
layers.Dense(projection_width, activation="relu"),
layers.Dense(projection_width),
layers.BatchNormalization(),
],
name="projector",
),
**kwargs,
)
class SimCLRAugmenter(preprocessing.Augmenter):
def __init__(
self,
value_range,
height=128,
width=128,
crop_area_factor=(0.08, 1.0),
aspect_ratio_factor=(3 / 4, 4 / 3),
grayscale_rate=0.2,
color_jitter_rate=0.8,
brightness_factor=0.2,
contrast_factor=0.8,
saturation_factor=(0.3, 0.7),
hue_factor=0.2,
**kwargs,
):
        super().__init__(
[
preprocessing.RandomFlip("horizontal"),
preprocessing.RandomCropAndResize(
target_size=(height, width),
crop_area_factor=crop_area_factor,
aspect_ratio_factor=aspect_ratio_factor,
),
preprocessing.MaybeApply(
preprocessing.Grayscale(output_channels=3), rate=grayscale_rate
),
preprocessing.MaybeApply(
preprocessing.RandomColorJitter(
value_range=value_range,
brightness_factor=brightness_factor,
contrast_factor=contrast_factor,
saturation_factor=saturation_factor,
hue_factor=hue_factor,
),
rate=color_jitter_rate,
),
],
**kwargs,
)
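For orientation, a hedged sketch of pre-training with the two classes above (the dataset shapes and the 96x96 crop size are illustrative assumptions; the test file below exercises the same path):

```python
import tensorflow as tf
from tensorflow.keras import optimizers
from keras_cv.losses import SimCLRLoss
from keras_cv.models import ResNet50V2
from keras_cv.training import SimCLRAugmenter
from keras_cv.training import SimCLRTrainer

simclr = SimCLRTrainer(
    encoder=ResNet50V2(include_rescaling=True, include_top=False, pooling="avg"),
    # value_range must match the pixel range of the input images.
    augmenter=SimCLRAugmenter(value_range=(0, 255), height=96, width=96),
    projection_width=128,
)
simclr.compile(
    encoder_optimizer=optimizers.Adam(),
    encoder_loss=SimCLRLoss(temperature=0.5),
)
images = tf.random.uniform((8, 128, 128, 3), maxval=255)
simclr.fit(images)
```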
# Copyright 2022 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
from tensorflow.keras import optimizers
from keras_cv.losses import SimCLRLoss
from keras_cv.models import ResNet50V2
from keras_cv.training import SimCLRAugmenter
from keras_cv.training import SimCLRTrainer
class SimCLRTrainerTest(tf.test.TestCase):
def test_train_without_probing(self):
simclr_without_probing = SimCLRTrainer(
self.build_encoder(),
augmenter=SimCLRAugmenter(value_range=(0, 255)),
)
images = tf.random.uniform((10, 512, 512, 3))
simclr_without_probing.compile(
encoder_optimizer=optimizers.Adam(),
encoder_loss=SimCLRLoss(temperature=0.5),
)
simclr_without_probing.fit(images)
def build_encoder(self):
return ResNet50V2(include_rescaling=False, include_top=False, pooling="avg")
# Copyright 2022 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from keras_cv.utils.fill_utils import fill_rectangle
from keras_cv.utils.preprocessing import blend
from keras_cv.utils.preprocessing import parse_factor
from keras_cv.utils.preprocessing import transform
from keras_cv.utils.preprocessing import transform_value_range
from keras_cv.utils.train import convert_inputs_to_tf_dataset
from keras_cv.utils.train import scale_loss_for_distribution
# Copyright 2022 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
def normalize_tuple(value, n, name, allow_zero=False):
    """Transforms a non-negative/positive integer or iterable of integers into an integer tuple.
    Args:
        value: The value to validate and convert. Could be an int, or any
            iterable of ints.
        n: The size of the tuple to be returned.
        name: The name of the argument being validated, e.g. "strides" or
            "kernel_size". This is only used to format error messages.
        allow_zero: Defaults to False. A ValueError will be raised if zero is
            received and this param is False.
    Returns:
        A tuple of n integers.
    Raises:
        ValueError: If anything other than an int or an iterable of ints is
            passed, or if any value is negative (or zero when `allow_zero` is
            False).
    """
error_msg = (
f"The `{name}` argument must be a tuple of {n} " f"integers. Received: {value}"
)
if isinstance(value, int):
value_tuple = (value,) * n
else:
try:
value_tuple = tuple(value)
except TypeError:
raise ValueError(error_msg)
if len(value_tuple) != n:
raise ValueError(error_msg)
for single_value in value_tuple:
try:
int(single_value)
except (ValueError, TypeError):
            error_msg += (
                f" including element {single_value} of type {type(single_value)}"
            )
raise ValueError(error_msg)
if allow_zero:
unqualified_values = {v for v in value_tuple if v < 0}
req_msg = ">= 0"
else:
unqualified_values = {v for v in value_tuple if v <= 0}
req_msg = "> 0"
if unqualified_values:
        error_msg += (
            f" including {unqualified_values}"
            f" that do not satisfy the requirement `{req_msg}`."
        )
raise ValueError(error_msg)
return value_tuple
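A few illustrative calls (not from the source) showing the contract `normalize_tuple` enforces:

```python
normalize_tuple(3, n=2, name="kernel_size")     # -> (3, 3): scalars broadcast
normalize_tuple([1, 2], n=2, name="strides")    # -> (1, 2): iterables validated
normalize_tuple(0, n=2, name="strides")         # ValueError: zero disallowed
normalize_tuple(0, n=2, name="dilation", allow_zero=True)  # -> (0, 0)
```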
# Copyright 2022 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
from keras_cv import bounding_box
def _axis_mask(starts, ends, mask_len):
# index range of axis
batch_size = tf.shape(starts)[0]
axis_indices = tf.range(mask_len, dtype=starts.dtype)
axis_indices = tf.expand_dims(axis_indices, 0)
axis_indices = tf.tile(axis_indices, [batch_size, 1])
# mask of index bounds
axis_mask = tf.greater_equal(axis_indices, starts) & tf.less(axis_indices, ends)
return axis_mask
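# Worked example (hypothetical values): for starts=[[1]], ends=[[4]], and
# mask_len=6, _axis_mask returns [[False, True, True, True, False, False]],
# i.e. True exactly at the indices i with starts <= i < ends.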
def corners_to_mask(bounding_boxes, mask_shape):
    """Converts bounding boxes in corners format to boolean masks.
    Args:
        bounding_boxes: tensor of rectangle coordinates with shape (batch_size, 4)
            in corners format (x0, y0, x1, y1).
        mask_shape: a (width, height) tuple indicating the output width and
            height of the masks.
    Returns:
        boolean masks with shape (batch_size, height, width) where True values
            indicate positions within the bounding box coordinates.
    """
mask_width, mask_height = mask_shape
x0, y0, x1, y1 = tf.split(bounding_boxes, [1, 1, 1, 1], axis=-1)
w_mask = _axis_mask(x0, x1, mask_width)
h_mask = _axis_mask(y0, y1, mask_height)
w_mask = tf.expand_dims(w_mask, axis=1)
h_mask = tf.expand_dims(h_mask, axis=2)
masks = tf.logical_and(w_mask, h_mask)
return masks
def fill_rectangle(images, centers_x, centers_y, widths, heights, fill_values):
"""Fill rectangles with fill value into images.
Args:
images: Tensor of images to fill rectangles into.
centers_x: Tensor of positions of the rectangle centers on the x-axis.
centers_y: Tensor of positions of the rectangle centers on the y-axis.
        widths: Tensor of widths of the rectangles.
        heights: Tensor of heights of the rectangles.
fill_values: Tensor with same shape as images to get rectangle fill from.
Returns:
images with filled rectangles.
"""
images_shape = tf.shape(images)
images_height = images_shape[1]
images_width = images_shape[2]
xywh = tf.stack([centers_x, centers_y, widths, heights], axis=1)
xywh = tf.cast(xywh, tf.float32)
corners = bounding_box.convert_format(xywh, source="center_xywh", target="xyxy")
mask_shape = (images_width, images_height)
is_rectangle = corners_to_mask(corners, mask_shape)
is_rectangle = tf.expand_dims(is_rectangle, -1)
images = tf.where(is_rectangle, fill_values, images)
return images
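A minimal sketch of `fill_rectangle` used as a cutout-style fill (the numbers mirror `FillRectangleTest.test_rectangle_position` below; everything else is an illustrative assumption):

```python
import tensorflow as tf
from keras_cv.utils import fill_utils

images = tf.ones((1, 8, 8, 1))      # batch of one 8x8 single-channel image
centers_x = tf.constant([4])        # rectangle center along x
centers_y = tf.constant([3])        # rectangle center along y
widths = tf.constant([5])
heights = tf.constant([3])
fill = tf.zeros_like(images)        # zero-fill, i.e. cutout
out = fill_utils.fill_rectangle(
    images, centers_x, centers_y, widths, heights, fill
)
# out[0, 2:5, 2:7, 0] is now all zeros; every other pixel is still 1.
```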
# Copyright 2022 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
from keras_cv.utils import fill_utils
class BoundingBoxToMaskTest(tf.test.TestCase):
def _run_test(self, corners, expected):
mask = fill_utils.corners_to_mask(corners, mask_shape=(6, 6))
mask = tf.cast(mask, dtype=tf.int32)
tf.assert_equal(mask, expected)
def test_corners_whole(self):
expected = tf.constant(
[
[0, 1, 1, 1, 0, 0],
[0, 1, 1, 1, 0, 0],
[0, 1, 1, 1, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
],
dtype=tf.int32,
)
corners = tf.constant([[1, 0, 4, 3]], dtype=tf.float32)
self._run_test(corners, expected)
def test_corners_frac(self):
expected = tf.constant(
[
[0, 0, 0, 0, 0, 0],
[0, 0, 1, 1, 1, 0],
[0, 0, 1, 1, 1, 0],
[0, 0, 1, 1, 1, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
]
)
corners = tf.constant([[1.5, 0.5, 4.5, 3.5]], dtype=tf.float32)
self._run_test(corners, expected)
def test_width_zero(self):
expected = tf.constant(
[
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
]
)
corners = tf.constant([[0, 0, 0, 3]], dtype=tf.float32)
self._run_test(corners, expected)
def test_height_zero(self):
expected = tf.constant(
[
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
]
)
corners = tf.constant([[1, 0, 4, 0]], dtype=tf.float32)
self._run_test(corners, expected)
def test_width_negative(self):
expected = tf.constant(
[
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
]
)
corners = tf.constant([[1, 0, -2, 3]], dtype=tf.float32)
self._run_test(corners, expected)
def test_height_negative(self):
expected = tf.constant(
[
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
]
)
corners = tf.constant([[1, 0, 4, -2]], dtype=tf.float32)
self._run_test(corners, expected)
def test_width_out_of_lower_bound(self):
expected = tf.constant(
[
[1, 1, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
]
)
corners = tf.constant([[-2, -2, 2, 3]], dtype=tf.float32)
self._run_test(corners, expected)
def test_width_out_of_upper_bound(self):
expected = tf.constant(
[
[0, 0, 0, 0, 1, 1],
[0, 0, 0, 0, 1, 1],
[0, 0, 0, 0, 1, 1],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
]
)
corners = tf.constant([[4, 0, 8, 3]], dtype=tf.float32)
self._run_test(corners, expected)
def test_height_out_of_lower_bound(self):
expected = tf.constant(
[
[0, 1, 1, 1, 0, 0],
[0, 1, 1, 1, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
]
)
corners = tf.constant([[1, -3, 4, 2]], dtype=tf.float32)
self._run_test(corners, expected)
def test_height_out_of_upper_bound(self):
expected = tf.constant(
[
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 1, 1, 1, 0, 0],
[0, 1, 1, 1, 0, 0],
]
)
corners = tf.constant([[1, 4, 4, 9]], dtype=tf.float32)
self._run_test(corners, expected)
def test_start_out_of_upper_bound(self):
expected = tf.constant(
[
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
]
)
corners = tf.constant([[8, 8, 10, 12]], dtype=tf.float32)
self._run_test(corners, expected)
class FillRectangleTest(tf.test.TestCase):
def _run_test(self, img_w, img_h, cent_x, cent_y, rec_w, rec_h, expected):
batch_size = 1
batch_shape = (batch_size, img_h, img_w, 1)
images = tf.ones(batch_shape, dtype=tf.int32)
centers_x = tf.fill([batch_size], cent_x)
centers_y = tf.fill([batch_size], cent_y)
width = tf.fill([batch_size], rec_w)
height = tf.fill([batch_size], rec_h)
fill = tf.zeros_like(images)
filled_images = fill_utils.fill_rectangle(
images, centers_x, centers_y, width, height, fill
)
# remove batch dimension and channel dimension
filled_images = filled_images[0, ..., 0]
tf.assert_equal(filled_images, expected)
def test_rectangle_position(self):
img_w, img_h = 8, 8
cent_x, cent_y = 4, 3
rec_w, rec_h = 5, 3
expected = tf.constant(
[
[1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 0, 0, 0, 0, 0, 1],
[1, 1, 0, 0, 0, 0, 0, 1],
[1, 1, 0, 0, 0, 0, 0, 1],
[1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1],
],
dtype=tf.int32,
)
self._run_test(img_w, img_h, cent_x, cent_y, rec_w, rec_h, expected)
def test_width_out_of_lower_bound(self):
img_w, img_h = 8, 8
cent_x, cent_y = 1, 3
rec_w, rec_h = 5, 3
# assert width is truncated when cent_x - rec_w < 0
expected = tf.constant(
[
[1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1],
[0, 0, 0, 0, 1, 1, 1, 1],
[0, 0, 0, 0, 1, 1, 1, 1],
[0, 0, 0, 0, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1],
],
dtype=tf.int32,
)
self._run_test(img_w, img_h, cent_x, cent_y, rec_w, rec_h, expected)
def test_width_out_of_upper_bound(self):
img_w, img_h = 8, 8
cent_x, cent_y = 6, 3
rec_w, rec_h = 5, 3
# assert width is truncated when cent_x + rec_w > img_w
expected = tf.constant(
[
[1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 0, 0, 0, 0],
[1, 1, 1, 1, 0, 0, 0, 0],
[1, 1, 1, 1, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1],
],
dtype=tf.int32,
)
self._run_test(img_w, img_h, cent_x, cent_y, rec_w, rec_h, expected)
def test_height_out_of_lower_bound(self):
img_w, img_h = 8, 8
cent_x, cent_y = 4, 1
rec_w, rec_h = 3, 5
# assert height is truncated when cent_y - rec_h < 0
expected = tf.constant(
[
[1, 1, 1, 0, 0, 0, 1, 1],
[1, 1, 1, 0, 0, 0, 1, 1],
[1, 1, 1, 0, 0, 0, 1, 1],
[1, 1, 1, 0, 0, 0, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1],
],
dtype=tf.int32,
)
self._run_test(img_w, img_h, cent_x, cent_y, rec_w, rec_h, expected)
def test_height_out_of_upper_bound(self):
img_w, img_h = 8, 8
cent_x, cent_y = 4, 6
rec_w, rec_h = 3, 5
# assert height is truncated when cent_y + rec_h > img_h
expected = tf.constant(
[
[1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 0, 0, 0, 1, 1],
[1, 1, 1, 0, 0, 0, 1, 1],
[1, 1, 1, 0, 0, 0, 1, 1],
[1, 1, 1, 0, 0, 0, 1, 1],
],
dtype=tf.int32,
)
self._run_test(img_w, img_h, cent_x, cent_y, rec_w, rec_h, expected)
def test_different_fill(self):
batch_size = 2
img_w, img_h = 5, 5
cent_x, cent_y = 2, 2
rec_w, rec_h = 3, 3
batch_shape = (batch_size, img_h, img_w, 1)
images = tf.ones(batch_shape, dtype=tf.int32)
centers_x = tf.fill([batch_size], cent_x)
centers_y = tf.fill([batch_size], cent_y)
width = tf.fill([batch_size], rec_w)
height = tf.fill([batch_size], rec_h)
fill = tf.stack([tf.fill(images[0].shape, 2), tf.fill(images[1].shape, 3)])
filled_images = fill_utils.fill_rectangle(
images, centers_x, centers_y, width, height, fill
)
# remove channel dimension
filled_images = filled_images[..., 0]
expected = tf.constant(
[
[
[1, 1, 1, 1, 1],
[1, 2, 2, 2, 1],
[1, 2, 2, 2, 1],
[1, 2, 2, 2, 1],
[1, 1, 1, 1, 1],
],
[
[1, 1, 1, 1, 1],
[1, 3, 3, 3, 1],
[1, 3, 3, 3, 1],
[1, 3, 3, 3, 1],
[1, 1, 1, 1, 1],
],
],
dtype=tf.int32,
)
tf.assert_equal(filled_images, expected)