dcuai / dlexamples / Commits / 1a3c83d6

Commit 1a3c83d6, authored Jan 10, 2023 by zhanggzh
Add keras-cv models and training code
parent 9846958a

Showing 20 changed files with 2603 additions and 0 deletions (+2603 −0)
Keras/keras-cv/keras_cv/ops/box_matcher.py  +240 −0
Keras/keras-cv/keras_cv/ops/box_matcher_test.py  +127 −0
Keras/keras-cv/keras_cv/ops/iou_3d.py  +48 −0
Keras/keras-cv/keras_cv/ops/iou_3d_test.py  +54 −0
Keras/keras-cv/keras_cv/ops/point_cloud.py  +311 −0
Keras/keras-cv/keras_cv/ops/point_cloud_test.py  +261 −0
Keras/keras-cv/keras_cv/ops/sampling.py  +78 −0
Keras/keras-cv/keras_cv/ops/sampling_test.py  +145 −0
Keras/keras-cv/keras_cv/ops/target_gather.py  +124 −0
Keras/keras-cv/keras_cv/ops/target_gather_test.py  +117 −0
Keras/keras-cv/keras_cv/training/__init__.py  +17 −0
Keras/keras-cv/keras_cv/training/contrastive/__init__.py  +13 −0
Keras/keras-cv/keras_cv/training/contrastive/contrastive_trainer.py  +253 −0
Keras/keras-cv/keras_cv/training/contrastive/contrastive_trainer_test.py  +171 −0
Keras/keras-cv/keras_cv/training/contrastive/simclr_trainer.py  +92 −0
Keras/keras-cv/keras_cv/training/contrastive/simclr_trainer_test.py  +40 −0
Keras/keras-cv/keras_cv/utils/__init__.py  +21 −0
Keras/keras-cv/keras_cv/utils/conv_utils.py  +69 −0
Keras/keras-cv/keras_cv/utils/fill_utils.py  +81 −0
Keras/keras-cv/keras_cv/utils/fill_utils_test.py  +341 −0
Keras/keras-cv/keras_cv/ops/box_matcher.py (new file, mode 100644)
# Copyright 2022 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"Argmax-based box matching"
from
typing
import
List
from
typing
import
Tuple
import
tensorflow
as
tf
class
ArgmaxBoxMatcher
:
"""Box matching logic based on argmax of highest value (e.g., IOU).
This class computes matches from a similarity matrix. Each row will be
matched to at least one column, the matched result can either be positive
/ negative, or simply ignored depending on the setting.
The settings include `thresholds` and `match_values`, for example if:
1) thresholds=[negative_threshold, positive_threshold], and
match_values=[negative_value=0, ignore_value=-1, positive_value=1]: the rows will
be assigned to positive_value if its argmax result >=
positive_threshold; the rows will be assigned to negative_value if its
argmax result < negative_threshold, and the rows will be assigned
to ignore_value if its argmax result is between [negative_threshold, positive_threshold).
2) thresholds=[negative_threshold, positive_threshold], and
match_values=[ignore_value=-1, negative_value=0, positive_value=1]: the rows will
be assigned to positive_value if its argmax result >=
positive_threshold; the rows will be assigned to ignore_value if its
argmax result < negative_threshold, and the rows will be assigned
to negative_value if its argmax result is between [negative_threshold ,positive_threshold).
This is different from case 1) by swapping first two
values.
3) thresholds=[positive_threshold], and
match_values=[negative_values, positive_value]: the rows will be assigned to
positive value if its argmax result >= positive_threshold; the rows
will be assigned to negative_value if its argmax result < negative_threshold.
Args:
thresholds: A sorted list of floats to classify the matches into
different results (e.g. positive or negative or ignored match). The
list will be prepended with -Inf and and appended with +Inf.
match_values: A list of integers representing matched results (e.g.
positive or negative or ignored match). len(`match_values`) must
equal to len(`thresholds`) + 1.
force_match_for_each_col: each row will be argmax matched to at
least one column. This means some columns will be matched to
multiple rows while some columns will not be matched to any rows.
Filtering by `thresholds` will make less columns match to positive
result. Setting this to True guarantees that each column will be
matched to positive result to at least one row.
Raises:
ValueError: if `thresholds` not sorted or
len(`match_values`) != len(`thresholds`) + 1
Usage:
```python
box_matcher = keras_cv.ops.ArgmaxBoxMatcher([0.3, 0.7], [-1, 0, 1])
iou_metric = keras_cv.bounding_box.compute_iou(anchors, gt_boxes)
matched_columns, matched_match_values = box_matcher(iou_metric)
cls_mask = tf.less_equal(matched_match_values, 0)
```
TODO(tanzhenyu): document when to use which mode.
"""
def
__init__
(
self
,
thresholds
:
List
[
float
],
match_values
:
List
[
int
],
force_match_for_each_col
:
bool
=
False
,
):
if
sorted
(
thresholds
)
!=
thresholds
:
raise
ValueError
(
f
"`threshold` must be sorted, got
{
thresholds
}
"
)
self
.
match_values
=
match_values
if
len
(
match_values
)
!=
len
(
thresholds
)
+
1
:
raise
ValueError
(
f
"len(`match_values`) must be len(`thresholds`) + 1, got "
f
"match_values
{
match_values
}
, thresholds
{
thresholds
}
"
)
thresholds
.
insert
(
0
,
-
float
(
"inf"
))
thresholds
.
append
(
float
(
"inf"
))
self
.
thresholds
=
thresholds
self
.
force_match_for_each_col
=
force_match_for_each_col
def
__call__
(
self
,
similarity_matrix
:
tf
.
Tensor
)
->
Tuple
[
tf
.
Tensor
,
tf
.
Tensor
]:
"""Matches each row to a column based on argmax
TODO(tanzhenyu): consider swapping rows and cols.
Args:
similarity_matrix: A float Tensor of shape [num_rows, num_cols] or
[batch_size, num_rows, num_cols] representing any similarity metric.
Returns:
matched_columns: An integer tensor of shape [num_rows] or [batch_size,
num_rows] storing the index of the matched colum for each row.
matched_values: An integer tensor of shape [num_rows] or [batch_size,
num_rows] storing the match result (positive match, negative match,
ignored match).
"""
squeeze_result
=
False
if
len
(
similarity_matrix
.
shape
)
==
2
:
squeeze_result
=
True
similarity_matrix
=
tf
.
expand_dims
(
similarity_matrix
,
axis
=
0
)
static_shape
=
similarity_matrix
.
shape
.
as_list
()
num_rows
=
static_shape
[
1
]
or
tf
.
shape
(
similarity_matrix
)[
1
]
batch_size
=
static_shape
[
0
]
or
tf
.
shape
(
similarity_matrix
)[
0
]
def
_match_when_cols_are_empty
():
"""Performs matching when the rows of similarity matrix are empty.
When the rows are empty, all detections are false positives. So we return
a tensor of -1's to indicate that the rows do not match to any columns.
Returns:
matched_columns: An integer tensor of shape [batch_size, num_rows]
storing the index of the matched column for each row.
matched_values: An integer tensor of shape [batch_size, num_rows]
storing the match type indicator (e.g. positive or negative
or ignored match).
"""
with
tf
.
name_scope
(
"empty_gt_boxes"
):
matched_columns
=
tf
.
zeros
([
batch_size
,
num_rows
],
dtype
=
tf
.
int32
)
matched_values
=
-
tf
.
ones
([
batch_size
,
num_rows
],
dtype
=
tf
.
int32
)
return
matched_columns
,
matched_values
def
_match_when_cols_are_non_empty
():
"""Performs matching when the rows of similarity matrix are non empty.
Returns:
matched_columns: An integer tensor of shape [batch_size, num_rows]
storing the index of the matched column for each row.
matched_values: An integer tensor of shape [batch_size, num_rows]
storing the match type indicator (e.g. positive or negative
or ignored match).
"""
with
tf
.
name_scope
(
"non_empty_gt_boxes"
):
matched_columns
=
tf
.
argmax
(
similarity_matrix
,
axis
=-
1
,
output_type
=
tf
.
int32
)
# Get logical indices of ignored and unmatched columns as tf.int64
matched_vals
=
tf
.
reduce_max
(
similarity_matrix
,
axis
=-
1
)
matched_values
=
tf
.
zeros
([
batch_size
,
num_rows
],
tf
.
int32
)
match_dtype
=
matched_vals
.
dtype
for
(
ind
,
low
,
high
)
in
zip
(
self
.
match_values
,
self
.
thresholds
[:
-
1
],
self
.
thresholds
[
1
:]
):
low_threshold
=
tf
.
cast
(
low
,
match_dtype
)
high_threshold
=
tf
.
cast
(
high
,
match_dtype
)
mask
=
tf
.
logical_and
(
tf
.
greater_equal
(
matched_vals
,
low_threshold
),
tf
.
less
(
matched_vals
,
high_threshold
),
)
matched_values
=
self
.
_set_values_using_indicator
(
matched_values
,
mask
,
ind
)
if
self
.
force_match_for_each_col
:
# [batch_size, num_cols], for each column (groundtruth_box), find the
# best matching row (anchor).
matching_rows
=
tf
.
argmax
(
input
=
similarity_matrix
,
axis
=
1
,
output_type
=
tf
.
int32
)
# [batch_size, num_cols, num_rows], a transposed 0-1 mapping matrix M,
# where M[j, i] = 1 means column j is matched to row i.
column_to_row_match_mapping
=
tf
.
one_hot
(
matching_rows
,
depth
=
num_rows
)
# [batch_size, num_rows], for each row (anchor), find the matched
# column (groundtruth_box).
force_matched_columns
=
tf
.
argmax
(
input
=
column_to_row_match_mapping
,
axis
=
1
,
output_type
=
tf
.
int32
)
# [batch_size, num_rows]
force_matched_column_mask
=
tf
.
cast
(
tf
.
reduce_max
(
column_to_row_match_mapping
,
axis
=
1
),
tf
.
bool
)
# [batch_size, num_rows]
matched_columns
=
tf
.
where
(
force_matched_column_mask
,
force_matched_columns
,
matched_columns
,
)
matched_values
=
tf
.
where
(
force_matched_column_mask
,
self
.
match_values
[
-
1
]
*
tf
.
ones
([
batch_size
,
num_rows
],
dtype
=
tf
.
int32
),
matched_values
,
)
return
matched_columns
,
matched_values
num_gt_boxes
=
(
similarity_matrix
.
shape
.
as_list
()[
-
1
]
or
tf
.
shape
(
similarity_matrix
)[
-
1
]
)
matched_columns
,
matched_values
=
tf
.
cond
(
pred
=
tf
.
greater
(
num_gt_boxes
,
0
),
true_fn
=
_match_when_cols_are_non_empty
,
false_fn
=
_match_when_cols_are_empty
,
)
if
squeeze_result
:
matched_columns
=
tf
.
squeeze
(
matched_columns
,
axis
=
0
)
matched_values
=
tf
.
squeeze
(
matched_values
,
axis
=
0
)
return
matched_columns
,
matched_values
def
_set_values_using_indicator
(
self
,
x
,
indicator
,
val
):
"""Set the indicated fields of x to val.
Args:
x: tensor.
indicator: boolean with same shape as x.
val: scalar with value to set.
Returns:
modified tensor.
"""
indicator
=
tf
.
cast
(
indicator
,
x
.
dtype
)
return
tf
.
add
(
tf
.
multiply
(
x
,
1
-
indicator
),
val
*
indicator
)
def
get_config
(
self
):
config
=
{
"thresholds"
:
self
.
thresholds
[
1
:
-
1
],
"match_values"
:
self
.
match_values
,
"force_match_for_each_col"
:
self
.
force_match_for_each_col
,
}
return
config
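The thresholding loop in `_match_when_cols_are_non_empty` can be mirrored in plain NumPy to make the bucketing semantics concrete. This is an illustrative sketch only: `bucketize_matches` is our name, not part of KerasCV, and it handles just the unbatched `[num_rows, num_cols]` case without the force-match pass. It uses the same inputs as `test_box_matcher_unbatched` below.

```python
import numpy as np


def bucketize_matches(similarity, thresholds, match_values):
    """Assign each row a match value from its best (argmax) similarity.

    Mirrors ArgmaxBoxMatcher's thresholding: thresholds are extended with
    -inf / +inf, and the i-th half-open bucket [t_i, t_{i+1}) maps to
    match_values[i].
    """
    similarity = np.asarray(similarity, dtype=np.float32)
    matched_columns = similarity.argmax(axis=-1)
    matched_vals = similarity.max(axis=-1)

    edges = [-np.inf] + list(thresholds) + [np.inf]
    matched_values = np.zeros(similarity.shape[0], dtype=np.int32)
    for value, low, high in zip(match_values, edges[:-1], edges[1:]):
        mask = (matched_vals >= low) & (matched_vals < high)
        matched_values[mask] = value
    return matched_columns, matched_values


# Same setup as test_box_matcher_unbatched: thresholds [0.0, 0.2, 0.5]
# and match_values [-3, -2, -1, 1].
cols, vals = bucketize_matches(
    [[0.04, 0, 0, 0], [0, 0, 1.0, 0]], [0.0, 0.2, 0.5], [-3, -2, -1, 1]
)
print(cols.tolist(), vals.tolist())  # [0, 2] [-2, 1]
```

Row 0's best similarity (0.04) falls in the [0.0, 0.2) bucket, so it is marked negative (-2); row 1's best (1.0) lands in [0.5, inf), so it is a positive match (1) to column 2.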
Keras/keras-cv/keras_cv/ops/box_matcher_test.py (new file, mode 100644)
# Copyright 2022 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf

from keras_cv.ops.box_matcher import ArgmaxBoxMatcher


class ArgmaxBoxMatcherTest(tf.test.TestCase):
    def test_box_matcher_invalid_length(self):
        fg_threshold = 0.5
        bg_thresh_hi = 0.2
        bg_thresh_lo = 0.0

        with self.assertRaisesRegex(ValueError, "must be len"):
            _ = ArgmaxBoxMatcher(
                thresholds=[bg_thresh_lo, bg_thresh_hi, fg_threshold],
                match_values=[-3, -2, -1],
            )

    def test_box_matcher_unsorted_thresholds(self):
        fg_threshold = 0.5
        bg_thresh_hi = 0.2
        bg_thresh_lo = 0.0

        with self.assertRaisesRegex(ValueError, "must be sorted"):
            _ = ArgmaxBoxMatcher(
                thresholds=[bg_thresh_hi, bg_thresh_lo, fg_threshold],
                match_values=[-3, -2, -1, 1],
            )

    def test_box_matcher_unbatched(self):
        sim_matrix = tf.constant(
            [[0.04, 0, 0, 0], [0, 0, 1.0, 0]], dtype=tf.float32
        )

        fg_threshold = 0.5
        bg_thresh_hi = 0.2
        bg_thresh_lo = 0.0

        matcher = ArgmaxBoxMatcher(
            thresholds=[bg_thresh_lo, bg_thresh_hi, fg_threshold],
            match_values=[-3, -2, -1, 1],
        )
        match_indices, matched_values = matcher(sim_matrix)
        positive_matches = tf.greater_equal(matched_values, 0)
        negative_matches = tf.equal(matched_values, -2)

        self.assertAllEqual(positive_matches.numpy(), [False, True])
        self.assertAllEqual(negative_matches.numpy(), [True, False])
        self.assertAllEqual(match_indices.numpy(), [0, 2])
        self.assertAllEqual(matched_values.numpy(), [-2, 1])

    def test_box_matcher_batched(self):
        sim_matrix = tf.constant(
            [[[0.04, 0, 0, 0], [0, 0, 1.0, 0]]], dtype=tf.float32
        )

        fg_threshold = 0.5
        bg_thresh_hi = 0.2
        bg_thresh_lo = 0.0

        matcher = ArgmaxBoxMatcher(
            thresholds=[bg_thresh_lo, bg_thresh_hi, fg_threshold],
            match_values=[-3, -2, -1, 1],
        )
        match_indices, matched_values = matcher(sim_matrix)
        positive_matches = tf.greater_equal(matched_values, 0)
        negative_matches = tf.equal(matched_values, -2)

        self.assertAllEqual(positive_matches.numpy(), [[False, True]])
        self.assertAllEqual(negative_matches.numpy(), [[True, False]])
        self.assertAllEqual(match_indices.numpy(), [[0, 2]])
        self.assertAllEqual(matched_values.numpy(), [[-2, 1]])

    def test_box_matcher_force_match(self):
        sim_matrix = tf.constant(
            [
                [0, 0.04, 0, 0.1],
                [0, 0, 1.0, 0],
                [0.1, 0, 0, 0],
                [0, 0, 0, 0.6],
            ],
            dtype=tf.float32,
        )

        fg_threshold = 0.5
        bg_thresh_hi = 0.2
        bg_thresh_lo = 0.0

        matcher = ArgmaxBoxMatcher(
            thresholds=[bg_thresh_lo, bg_thresh_hi, fg_threshold],
            match_values=[-3, -2, -1, 1],
            force_match_for_each_col=True,
        )
        match_indices, matched_values = matcher(sim_matrix)
        positive_matches = tf.greater_equal(matched_values, 0)
        negative_matches = tf.equal(matched_values, -2)

        self.assertAllEqual(positive_matches.numpy(), [True, True, True, True])
        self.assertAllEqual(
            negative_matches.numpy(), [False, False, False, False]
        )
        # The first anchor cannot be matched to the 4th gt box, given that it
        # is matched to the last anchor.
        self.assertAllEqual(match_indices.numpy(), [1, 2, 0, 3])
        self.assertAllEqual(matched_values.numpy(), [1, 1, 1, 1])

    def test_box_matcher_empty_gt_boxes(self):
        sim_matrix = tf.constant([[], []], dtype=tf.float32)

        fg_threshold = 0.5
        bg_thresh_hi = 0.2
        bg_thresh_lo = 0.0

        matcher = ArgmaxBoxMatcher(
            thresholds=[bg_thresh_lo, bg_thresh_hi, fg_threshold],
            match_values=[-3, -2, -1, 1],
        )
        match_indices, matched_values = matcher(sim_matrix)
        positive_matches = tf.greater_equal(matched_values, 0)
        ignore_matches = tf.equal(matched_values, -1)

        self.assertAllEqual(positive_matches.numpy(), [False, False])
        self.assertAllEqual(ignore_matches.numpy(), [True, True])
        self.assertAllEqual(match_indices.numpy(), [0, 0])
        self.assertAllEqual(matched_values.numpy(), [-1, -1])
Keras/keras-cv/keras_cv/ops/iou_3d.py (new file, mode 100644)
# Copyright 2022 The KerasCV Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""IoU3D using a custom TF op."""
from tensorflow.python.framework import load_library
from tensorflow.python.platform import resource_loader


class IoU3D:
    """Implements IoU computation for 3D upright rotated bounding boxes.

    Note that this is implemented using a custom TensorFlow op. Initializing
    an IoU3D object will attempt to load the binary for that op.

    Boxes should have the format [center_x, center_y, center_z, dimension_x,
    dimension_y, dimension_z, heading (in radians)].

    Sample Usage:
    ```python
    y_true = [[0, 0, 0, 2, 2, 2, 0], [1, 1, 1, 2, 2, 2, 3 * math.pi / 4]]
    y_pred = [[1, 1, 1, 2, 2, 2, math.pi / 4], [1, 1, 1, 2, 2, 2, 0]]
    iou = IoU3D()
    iou(y_true, y_pred)
    ```
    """

    def __init__(self):
        pairwise_iou_op = load_library.load_op_library(
            resource_loader.get_path_to_datafile(
                "../custom_ops/_keras_cv_custom_ops.so"
            )
        )
        self.iou_3d = pairwise_iou_op.pairwise_iou3d

    def __call__(self, y_true, y_pred):
        return self.iou_3d(y_true, y_pred)
Keras/keras-cv/keras_cv/ops/iou_3d_test.py (new file, mode 100644)
# Copyright 2022 The KerasCV Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Tests for IoU3D using custom op."""
import math
import os

import pytest
import tensorflow as tf

from keras_cv.ops import IoU3D


class IoU3DTest(tf.test.TestCase):
    @pytest.mark.skipif(
        "TEST_CUSTOM_OPS" not in os.environ
        or os.environ["TEST_CUSTOM_OPS"] != "true",
        reason="Requires binaries compiled from source",
    )
    def testOpCall(self):
        # Predicted boxes:
        #   0: a 2x2x2 box centered at 0,0,0, rotated 0 degrees
        #   1: a 2x2x2 box centered at 1,1,1, rotated 135 degrees
        # Ground truth boxes:
        #   0: a 2x2x2 box centered at 1,1,1, rotated 45 degrees (identical to
        #      predicted box 1)
        #   1: a 2x2x2 box centered at 1,1,1, rotated 0 degrees
        box_preds = [
            [0, 0, 0, 2, 2, 2, 0],
            [1, 1, 1, 2, 2, 2, 3 * math.pi / 4],
        ]
        box_gt = [
            [1, 1, 1, 2, 2, 2, math.pi / 4],
            [1, 1, 1, 2, 2, 2, 0],
        ]

        # Predicted box 0 and both ground truth boxes overlap by 1/8th of the
        # box. Therefore, IoU is 1/15.
        # Predicted box 1 is the same as ground truth box 0, therefore IoU is 1.
        # Predicted box 1 shares an origin with ground truth box 1, but is
        # rotated by 135 degrees. Their IoU can be reduced to that of two
        # overlapping squares that share a center with the same offset of 135
        # degrees, which reduces to the square root of 0.5.
        expected_ious = [[1 / 15, 1 / 15], [1, 0.5**0.5]]

        iou_3d = IoU3D()

        self.assertAllClose(iou_3d(box_preds, box_gt), expected_ious)


if __name__ == "__main__":
    tf.test.main()
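The expected values in the test above can be checked analytically without the custom op. For predicted box 0 vs. either ground truth box, the two 2x2x2 cubes overlap in a 1x1x1 cube, so IoU = 1 / (8 + 8 - 1) = 1/15. For the rotated pair, a 135-degree offset between squares is geometrically the same as 45 degrees; the intersection of a 2x2 square with a 45-degree-rotated copy sharing its center is a regular octagon of area 8*sqrt(2) - 8 (integrate the region {|x| <= 1, |y| <= 1, |x| + |y| <= sqrt(2)} per quadrant), which gives an IoU of exactly 1/sqrt(2). A minimal check:

```python
import math

# Shifted case: two 2x2x2 cubes overlapping in a 1x1x1 cube.
iou_shifted = 1 / (8 + 8 - 1)  # 1/15

# Rotated case: intersection of the two 2x2 cross-sections is a regular
# octagon of area 8*sqrt(2) - 8; the z-extents coincide, so 3D IoU = 2D IoU.
intersection = 8 * math.sqrt(2) - 8
union = 4 + 4 - intersection
iou_rotated = intersection / union
print(iou_shifted, iou_rotated)  # 0.0666... 0.7071...
```

Algebraically, (8*sqrt(2) - 8) / (16 - 8*sqrt(2)) simplifies to (sqrt(2) - 1) / (sqrt(2)*(sqrt(2) - 1)) = 1/sqrt(2), matching the `0.5**0.5` entry in `expected_ious`.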
Keras/keras-cv/keras_cv/ops/point_cloud.py (new file, mode 100644)
# Copyright 2022 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf


def get_rank(tensor):
    return tensor.shape.ndims or tf.rank(tensor)


def _get_3d_rotation_matrix(yaw, roll, pitch):
    """Creates a 3x3 rotation matrix from yaw, roll, pitch (angles in radians).

    Note: Yaw -> Z, Roll -> X, Pitch -> Y.

    Args:
      yaw: float tensor representing a yaw angle in radians.
      roll: float tensor representing a roll angle in radians.
      pitch: float tensor representing a pitch angle in radians.

    Returns:
      A [3, 3] tensor corresponding to a rotation matrix.
    """

    def _UnitX(angle):
        return tf.reshape(
            [
                1.0,
                0.0,
                0.0,
                0.0,
                tf.cos(angle),
                -tf.sin(angle),
                0.0,
                tf.sin(angle),
                tf.cos(angle),
            ],
            shape=[3, 3],
        )

    def _UnitY(angle):
        return tf.reshape(
            [
                tf.cos(angle),
                0.0,
                tf.sin(angle),
                0.0,
                1.0,
                0.0,
                -tf.sin(angle),
                0.0,
                tf.cos(angle),
            ],
            shape=[3, 3],
        )

    def _UnitZ(angle):
        return tf.reshape(
            [
                tf.cos(angle),
                -tf.sin(angle),
                0.0,
                tf.sin(angle),
                tf.cos(angle),
                0.0,
                0.0,
                0.0,
                1.0,
            ],
            shape=[3, 3],
        )

    return tf.matmul(tf.matmul(_UnitZ(yaw), _UnitX(roll)), _UnitY(pitch))


def _center_xyzWHD_to_corner_xyz(boxes):
    """Converts boxes from center format to corner format.

    Args:
      boxes: [..., num_boxes, 7] float32 Tensor for 3d boxes in [x, y, z, dx,
        dy, dz, phi].

    Returns:
      corners: [..., num_boxes, 8, 3] float32 Tensor for 3d corners in
        [x, y, z].
    """
    # Relative corners w.r.t. the origin point.
    # This will return all corners in top-down counter-clockwise order instead
    # of only left-top and bottom-right.
    rel_corners = tf.constant(
        [
            [0.5, 0.5, 0.5],  # top
            [-0.5, 0.5, 0.5],  # top
            [-0.5, -0.5, 0.5],  # top
            [0.5, -0.5, 0.5],  # top
            [0.5, 0.5, -0.5],  # bottom
            [-0.5, 0.5, -0.5],  # bottom
            [-0.5, -0.5, -0.5],  # bottom
            [0.5, -0.5, -0.5],  # bottom
        ]
    )
    centers = boxes[..., :3]
    dimensions = boxes[..., 3:6]
    phi_world = boxes[..., 6]
    leading_shapes = boxes.shape.as_list()[:-1]
    cos = tf.cos(phi_world)
    sin = tf.sin(phi_world)
    zero = tf.zeros_like(cos)
    one = tf.ones_like(cos)
    rotations = tf.reshape(
        tf.stack([cos, -sin, zero, sin, cos, zero, zero, zero, one], axis=-1),
        leading_shapes + [3, 3],
    )
    # apply the delta to convert from centers to relative corners format
    rel_corners = tf.einsum("...ni,ji->...nji", dimensions, rel_corners)
    # apply rotation matrix on relative corners
    rel_corners = tf.einsum("...nij,...nkj->...nki", rotations, rel_corners)
    # translate back to absolute corners format
    corners = rel_corners + tf.reshape(centers, leading_shapes + [1, 3])
    return corners


def _is_on_lefthand_side(points, v1, v2):
    """Checks if points lie on a vector's direction or to its left.

    Args:
      points: float Tensor of [num_points, 2] of points to check.
      v1: float Tensor of [num_points, 2] of the starting point of the vector.
      v2: float Tensor of [num_points, 2] of the ending point of the vector.

    Returns:
      A boolean Tensor of [num_points] indicating whether each point is on
      the left of the vector or on the vector's direction.
    """
    # Prepare for broadcast: All point operations are on the right,
    # and all v1/v2 operations are on the left. This is faster than left/right
    # under the assumption that we have more points than vertices.
    points_x = points[..., tf.newaxis, :, 0]
    points_y = points[..., tf.newaxis, :, 1]
    v1_x = v1[..., 0, tf.newaxis]
    v2_x = v2[..., 0, tf.newaxis]
    v1_y = v1[..., 1, tf.newaxis]
    v2_y = v2[..., 1, tf.newaxis]
    d1 = (points_y - v1_y) * (v2_x - v1_x)
    d2 = (points_x - v1_x) * (v2_y - v1_y)
    return d1 >= d2


def _box_area(boxes):
    """Compute the area of 2-d boxes.

    Vertices must be ordered counter-clockwise. This function can
    technically handle any kind of convex polygon.

    Args:
      boxes: a float Tensor of [..., 4, 2] of boxes. The last two dimensions
        are the four corners of the box as (x, y). The corners must be given
        in counter-clockwise order.
    """
    boxes_roll = tf.roll(boxes, shift=1, axis=-2)
    det = (
        tf.reduce_sum(
            boxes[..., 0] * boxes_roll[..., 1]
            - boxes[..., 1] * boxes_roll[..., 0],
            axis=-1,
            keepdims=True,
        )
        / 2.0
    )
    return tf.abs(det)


def is_within_box2d(points, boxes):
    """Checks if 3d points are within 2d bounding boxes.

    Currently only xy format is supported.
    This function returns true if points are strictly inside the box or on an
    edge.

    Args:
      points: [num_points, 2] float32 Tensor for 2d points in xy format.
      boxes: [num_boxes, 4, 2] float32 Tensor for 2d boxes in xy format,
        counter-clockwise.

    Returns:
      boolean Tensor of shape [num_points, num_boxes]
    """
    v1, v2, v3, v4 = (
        boxes[..., 0, :],
        boxes[..., 1, :],
        boxes[..., 2, :],
        boxes[..., 3, :],
    )
    is_inside = tf.math.logical_and(
        tf.math.logical_and(
            _is_on_lefthand_side(points, v1, v2),
            _is_on_lefthand_side(points, v2, v3),
        ),
        tf.math.logical_and(
            _is_on_lefthand_side(points, v3, v4),
            _is_on_lefthand_side(points, v4, v1),
        ),
    )
    valid_area = tf.greater(_box_area(boxes), 0)
    is_inside = tf.math.logical_and(is_inside, valid_area)
    # swap the last two dimensions
    is_inside = tf.einsum("...ij->...ji", tf.cast(is_inside, tf.int32))
    return tf.cast(is_inside, tf.bool)


def is_within_box3d(points, boxes):
    """Checks if 3d points are within 3d bounding boxes.

    Currently only xyz format is supported.

    Args:
      points: [..., num_points, 3] float32 Tensor for 3d points in xyz format.
      boxes: [..., num_boxes, 7] float32 Tensor for 3d boxes in [x, y, z, dx,
        dy, dz, phi].

    Returns:
      boolean Tensor of shape [..., num_points, num_boxes] indicating whether
      the point belongs to the box.
    """
    # step 1 -- determine if points are within xy range

    # convert from center format to corner format
    boxes_corner = _center_xyzWHD_to_corner_xyz(boxes)
    # project to 2d boxes by only taking x, y on top plane
    boxes_2d = boxes_corner[..., 0:4, 0:2]
    # project to 2d points by only taking x, y
    points_2d = points[..., :2]
    # check whether points are within 2d boxes, [..., num_points, num_boxes]
    is_inside_2d = is_within_box2d(points_2d, boxes_2d)

    # step 2 -- determine if points are within z range

    [_, _, z, _, _, dz, _] = tf.split(boxes, 7, axis=-1)
    z = z[..., 0]
    dz = dz[..., 0]
    bottom = z - dz / 2.0
    # [..., 1, num_boxes]
    bottom = bottom[..., tf.newaxis, :]
    top = z + dz / 2.0
    top = top[..., tf.newaxis, :]
    # [..., num_points, 1]
    points_z = points[..., 2:]
    # [..., num_points, num_boxes]
    is_inside_z = tf.math.logical_and(
        tf.less_equal(points_z, top), tf.greater_equal(points_z, bottom)
    )
    return tf.math.logical_and(is_inside_z, is_inside_2d)


def coordinate_transform(points, pose):
    """Translate 'points' to coordinates according to 'pose' vector.

    The pose should contain 6 floating point values:
      translate_x, translate_y, translate_z: The translation to apply.
      yaw, roll, pitch: The rotation angles in radians.

    Args:
      points: Float shape [..., 3]: Points to transform to new coordinates.
      pose: Float shape [6]: [translate_x, translate_y, translate_z, yaw,
        roll, pitch]. The pose in the frame that 'points' comes from, and the
        definition of the rotation and translation angles to apply to points.

    Returns:
      'points' transformed to the coordinates defined by 'pose'.
    """
    translate_x = pose[0]
    translate_y = pose[1]
    translate_z = pose[2]

    # Translate the points so the origin is the pose's center.
    translation = tf.reshape(
        [translate_x, translate_y, translate_z], shape=[3]
    )
    translated_points = points + translation

    # Compose the rotations along the three axes.
    #
    # Note: Yaw->Z, Roll->X, Pitch->Y.
    yaw, roll, pitch = pose[3], pose[4], pose[5]
    rotation_matrix = _get_3d_rotation_matrix(yaw, roll, pitch)

    # Finally, rotate the points about the pose's origin according to the
    # rotation matrix.
    rotated_points = tf.einsum(
        "...i,...ij->...j", translated_points, rotation_matrix
    )
    return rotated_points


def spherical_coordinate_transform(points):
    """Converts points from xyz coordinates to spherical coordinates.

    See
    https://en.wikipedia.org/wiki/Spherical_coordinate_system#Coordinate_system_conversions
    for definitions of the transformations.

    Args:
      points: A floating point tensor with shape [..., 3], where the inner 3
        dimensions correspond to xyz coordinates.

    Returns:
      A floating point tensor with the same shape [..., 3], where the inner
      dimensions correspond to (dist, theta, phi), where phi corresponds to
      azimuth/yaw (rotation around z), and theta corresponds to
      pitch/inclination (rotation around y).
    """
    dist = tf.sqrt(tf.reduce_sum(tf.square(points), axis=-1))
    theta = tf.acos(points[..., 2] / tf.maximum(dist, 1e-7))
    # Note: tf.atan2 takes in (y, x).
    phi = tf.atan2(points[..., 1], points[..., 0])
    return tf.stack([dist, theta, phi], axis=-1)
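The `_box_area` helper above is the shoelace formula: summing cross terms between each vertex and its rolled neighbor gives twice the signed polygon area, and `abs()` discards the orientation sign. A plain-NumPy re-sketch (our own names, for illustration only) makes the mechanics easy to check against the square boxes used in `test_box_area` below:

```python
import numpy as np


def box_area(boxes):
    """Shoelace formula for [..., 4, 2] convex polygons, mirroring _box_area.

    Pairs each (x, y) vertex with its rolled neighbor, sums the cross terms
    x_i * y_{i-1} - y_i * x_{i-1}, and halves; abs() makes the result
    independent of winding order.
    """
    boxes = np.asarray(boxes, dtype=np.float64)
    rolled = np.roll(boxes, shift=1, axis=-2)
    det = (
        np.sum(
            boxes[..., 0] * rolled[..., 1] - boxes[..., 1] * rolled[..., 0],
            axis=-1,
            keepdims=True,
        )
        / 2.0
    )
    return np.abs(det)


unit_square = [[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]]
two_square = [[0.0, 0.0], [2.0, 0.0], [2.0, 2.0], [0.0, 2.0]]
print(box_area([unit_square, two_square]))  # [[1.] [4.]]
```

Because the formula only uses relative cross products, rotating the vertices (as the parameterized `test_box_area` cases do) leaves the area unchanged.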
Keras/keras-cv/keras_cv/ops/point_cloud_test.py (new file, mode 100644)
# Copyright 2022 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import tensorflow as tf
from absl.testing import parameterized

from keras_cv import ops


class Boxes3DTestCase(tf.test.TestCase, parameterized.TestCase):
    def test_convert_center_to_corners(self):
        boxes = tf.constant(
            [
                [[1, 2, 3, 4, 3, 6, 0], [1, 2, 3, 4, 3, 6, 0]],
                [
                    [1, 2, 3, 4, 3, 6, np.pi / 2.0],
                    [1, 2, 3, 4, 3, 6, np.pi / 2.0],
                ],
            ]
        )
        corners = ops._center_xyzWHD_to_corner_xyz(boxes)
        self.assertEqual((2, 2, 8, 3), corners.shape)
        for i in [0, 1]:
            self.assertAllClose(-1, np.min(corners[0, i, :, 0]))
            self.assertAllClose(3, np.max(corners[0, i, :, 0]))
            self.assertAllClose(0.5, np.min(corners[0, i, :, 1]))
            self.assertAllClose(3.5, np.max(corners[0, i, :, 1]))
            self.assertAllClose(0, np.min(corners[0, i, :, 2]))
            self.assertAllClose(6, np.max(corners[0, i, :, 2]))
        for i in [0, 1]:
            self.assertAllClose(-0.5, np.min(corners[1, i, :, 0]))
            self.assertAllClose(2.5, np.max(corners[1, i, :, 0]))
            self.assertAllClose(0.0, np.min(corners[1, i, :, 1]))
            self.assertAllClose(4.0, np.max(corners[1, i, :, 1]))
            self.assertAllClose(0, np.min(corners[1, i, :, 2]))
            self.assertAllClose(6, np.max(corners[1, i, :, 2]))

    def test_within_box2d(self):
        boxes = tf.constant(
            [[[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]]], dtype=tf.float32
        )
        points = tf.constant(
            [
                [-0.5, -0.5],
                [0.5, -0.5],
                [1.5, -0.5],
                [1.5, 0.5],
                [1.5, 1.5],
                [0.5, 1.5],
                [-0.5, 1.5],
                [-0.5, 0.5],
                [1.0, 1.0],
                [0.5, 0.5],
            ],
            dtype=tf.float32,
        )
        is_inside = ops.is_within_box2d(points, boxes)
        expected = [[False]] * 8 + [[True]] * 2
        self.assertAllEqual(expected, is_inside)

    def test_within_zero_box2d(self):
        bbox = tf.constant(
            [[[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]], dtype=tf.float32
        )
        points = tf.constant(
            [
                [-0.5, -0.5],
                [0.5, -0.5],
                [1.5, -0.5],
                [1.5, 0.5],
                [1.5, 1.5],
                [0.5, 1.5],
                [-0.5, 1.5],
                [-0.5, 0.5],
                [1.0, 1.0],
                [0.5, 0.5],
            ],
            dtype=tf.float32,
        )
        is_inside = ops.is_within_box2d(points, bbox)
        expected = [[False]] * 10
        self.assertAllEqual(expected, is_inside)

    def test_is_on_lefthand_side(self):
        v1 = tf.constant([[0.0, 0.0]], dtype=tf.float32)
        v2 = tf.constant([[1.0, 0.0]], dtype=tf.float32)
        p = tf.constant([[0.5, 0.5], [-1.0, -3], [-1.0, 1.0]], dtype=tf.float32)
        res = ops._is_on_lefthand_side(p, v1, v2)
        self.assertAllEqual([[True, False, True]], res)
        res = ops._is_on_lefthand_side(v1, v1, v2)
        self.assertAllEqual([[True]], res)
        res = ops._is_on_lefthand_side(v2, v1, v2)
        self.assertAllEqual([[True]], res)

    @parameterized.named_parameters(
        ("without_rotation", 0.0),
        ("with_rotation_1_rad", 1.0),
        ("with_rotation_2_rad", 2.0),
        ("with_rotation_3_rad", 3.0),
    )
    def test_box_area(self, angle):
        boxes = tf.constant(
            [
                [[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]],
                [[0.0, 0.0], [2.0, 0.0], [2.0, 1.0], [0.0, 1.0]],
                [[0.0, 0.0], [2.0, 0.0], [2.0, 2.0], [0.0, 2.0]],
            ],
            dtype=tf.float32,
        )
        expected = [[1.0], [2.0], [4.0]]

        def _rotate(bbox, theta):
            rotation_matrix = tf.reshape(
                [tf.cos(theta), -tf.sin(theta), tf.sin(theta), tf.cos(theta)],
                shape=(2, 2),
            )
            return tf.matmul(bbox, rotation_matrix)

        rotated_bboxes = _rotate(boxes, angle)
        res = ops._box_area(rotated_bboxes)
        self.assertAllClose(expected, res)

    def test_within_box3d(self):
        num_points, num_boxes = 19, 4
        # rotate the first box by pi / 2 so dim_x and dim_y are swapped.
        # The last box is a cube rotated by 45 degrees.
        bboxes = tf.constant(
            [
                [1.0, 2.0, 3.0, 6.0, 0.4, 6.0, np.pi / 2],
                [4.0, 5.0, 6.0, 7.0, 0.8, 7.0, 0.0],
                [0.4, 0.3, 0.2, 0.1, 0.1, 0.2, 0.0],
                [-10.0, -10.0, -10.0, 3.0, 3.0, 3.0, np.pi / 4],
            ],
            dtype=tf.float32,
        )
        points = tf.constant(
            [
                [1.0, 2.0, 3.0],  # box 0 (centroid)
                [0.8, 2.0, 3.0],  # box 0 (below x)
                [1.1, 2.0, 3.0],  # box 0 (above x)
                [1.3, 2.0, 3.0],  # box 0 (too far x)
                [0.7, 2.0, 3.0],  # box 0 (too far x)
                [4.0, 5.0, 6.0],  # box 1 (centroid)
                [4.0, 4.6, 6.0],  # box 1 (below y)
                [4.0, 5.4, 6.0],  # box 1 (above y)
                [4.0, 4.5, 6.0],  # box 1 (too far y)
                [4.0, 5.5, 6.0],  # box 1 (too far y)
                [0.4, 0.3, 0.2],  # box 2 (centroid)
                [0.4, 0.3, 0.1],  # box 2 (below z)
                [0.4, 0.3, 0.3],  # box 2 (above z)
                [0.4, 0.3, 0.0],  # box 2 (too far z)
                [0.4, 0.3, 0.4],  # box 2 (too far z)
                [5.0, 7.0, 8.0],  # none
                [1.0, 5.0, 3.6],  # box0, box1
                [-11.6, -10.0, -10.0],  # box3 (rotated corner point).
                [-11.4, -11.4, -10.0],  # not in box3, would be if not rotated.
            ],
            dtype=tf.float32,
        )
        expected_is_inside = np.array(
            [
                [True, False, False, False],
                [True, False, False, False],
                [True, False, False, False],
                [False, False, False, False],
                [False, False, False, False],
                [False, True, False, False],
                [False, True, False, False],
                [False, True, False, False],
                [False, False, False, False],
                [False, False, False, False],
                [False, False, True, False],
                [False, False, True, False],
                [False, False, True, False],
                [False, False, False, False],
                [False, False, False, False],
                [False, False, False, False],
                [True, True, False, False],
                [False, False, False, True],
                [False, False, False, False],
            ]
        )
        assert points.shape[0] == num_points
        assert bboxes.shape[0] == num_boxes
        assert expected_is_inside.shape[0] == num_points
        assert expected_is_inside.shape[1] == num_boxes
        is_inside = ops.is_within_box3d(points, bboxes)
        self.assertAllEqual([num_points, num_boxes], is_inside.shape)
        self.assertAllEqual(expected_is_inside, is_inside)
        # Add a batch dimension to the data and see that it still works
        # as expected.
        batch_size = 3
        points = tf.tile(points[tf.newaxis, ...], [batch_size, 1, 1])
        bboxes = tf.tile(bboxes[tf.newaxis, ...], [batch_size, 1, 1])
        is_inside = ops.is_within_box3d(points, bboxes)
        self.assertAllEqual([batch_size, num_points, num_boxes], is_inside.shape)
        for batch_idx in range(batch_size):
            self.assertAllEqual(expected_is_inside, is_inside[batch_idx])

    def testCoordinateTransform(self):
        # This is a validated test case from a real scene.
        #
        # A single point [1, 1, 3].
        point = tf.constant(
            [[[5736.94580078, 1264.85168457, 45.0271225]]], dtype=tf.float32
        )
        # Replicate the point to test broadcasting behavior.
        replicated_points = tf.tile(point, [2, 4, 1])
        # Pose of the car (x, y, z, yaw, roll, pitch).
        #
        # We negate the translations so that the coordinates are translated
        # such that the car is at the origin.
        pose = tf.constant(
            [
                -5728.77148438,
                -1264.42236328,
                -45.06399918,
                -3.10496902,
                0.03288471,
                0.00115049,
            ],
            dtype=tf.float32,
        )
        result = ops.coordinate_transform(replicated_points, pose)
        # We expect the point to be translated close to the car, and then
        # rotated mostly around the x-axis.
        # The result is device dependent; skip or ignore this test locally
        # if it fails.
        expected = np.tile([[[-8.184512, -0.13086952, -0.04200769]]], [2, 4, 1])
        self.assertAllClose(expected, result)

    def testSphericalCoordinatesTransform(self):
        np_xyz = np.random.randn(5, 6, 3)
        points = tf.constant(np_xyz, dtype=tf.float32)
        spherical_coordinates = ops.spherical_coordinate_transform(points)
        # Convert coordinates back to xyz to verify.
        dist = spherical_coordinates[..., 0]
        theta = spherical_coordinates[..., 1]
        phi = spherical_coordinates[..., 2]
        x = dist * np.sin(theta) * np.cos(phi)
        y = dist * np.sin(theta) * np.sin(phi)
        z = dist * np.cos(theta)
        self.assertAllClose(x, np_xyz[..., 0])
        self.assertAllClose(y, np_xyz[..., 1])
        self.assertAllClose(z, np_xyz[..., 2])
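The round-trip check in `testSphericalCoordinatesTransform` above can be sketched in plain Python: the test recovers `(x, y, z)` from `(dist, theta, phi)` with `theta` as the polar angle from +z and `phi` as the azimuth. The forward function below is an assumption inferred from that inverse, not the library implementation of `ops.spherical_coordinate_transform`.

```python
import math


def xyz_to_spherical(x, y, z):
    """Forward transform consistent with the test's inverse:
    dist = radius, theta = polar angle from +z, phi = azimuth."""
    dist = math.sqrt(x * x + y * y + z * z)
    theta = math.acos(z / dist)
    phi = math.atan2(y, x)
    return dist, theta, phi


def spherical_to_xyz(dist, theta, phi):
    """Inverse transform, exactly as written in the test body."""
    x = dist * math.sin(theta) * math.cos(phi)
    y = dist * math.sin(theta) * math.sin(phi)
    z = dist * math.cos(theta)
    return x, y, z


# Round-trip a sample point, mirroring the test's verification strategy.
p = (1.0, -2.0, 3.0)
back = spherical_to_xyz(*xyz_to_spherical(*p))
assert all(abs(a - b) < 1e-9 for a, b in zip(p, back))
```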
Keras/keras-cv/keras_cv/ops/sampling.py
# Copyright 2022 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf


def balanced_sample(
    positive_matches: tf.Tensor,
    negative_matches: tf.Tensor,
    num_samples: int,
    positive_fraction: float,
):
    """Sampling ops to balance positive and negative samples, deals with both
    batched and unbatched inputs.

    Args:
        positive_matches: [N] or [batch_size, N] boolean Tensor, True for
            indicating the index is a positive sample
        negative_matches: [N] or [batch_size, N] boolean Tensor, True for
            indicating the index is a negative sample
        num_samples: int, representing the number of samples to collect
        positive_fraction: float. 0.5 means positive samples should be half
            of all collected samples.

    Returns:
        selected_indicators: [N] or [batch_size, N] integer Tensor, 1 for
            indicating the index is sampled, 0 for indicating the index is
            not sampled.
    """
    N = positive_matches.get_shape().as_list()[-1]
    if N < num_samples:
        raise ValueError(
            f"passed in {positive_matches.shape} has less element than "
            f"{num_samples}"
        )
    # random_val = tf.random.uniform(tf.shape(positive_matches), minval=0., maxval=1.)
    zeros = tf.zeros_like(positive_matches, dtype=tf.float32)
    ones = tf.ones_like(positive_matches, dtype=tf.float32)
    ones_rand = ones + tf.random.uniform(ones.shape, minval=-0.2, maxval=0.2)
    halfs = 0.5 * tf.ones_like(positive_matches, dtype=tf.float32)
    halfs_rand = halfs + tf.random.uniform(halfs.shape, minval=-0.2, maxval=0.2)
    values = zeros
    values = tf.where(positive_matches, ones_rand, values)
    values = tf.where(negative_matches, halfs_rand, values)
    num_pos_samples = int(num_samples * positive_fraction)
    valid_matches = tf.logical_or(positive_matches, negative_matches)
    # this might contain negative samples as well
    _, positive_indices = tf.math.top_k(values, k=num_pos_samples)
    selected_indicators = tf.cast(
        tf.reduce_sum(tf.one_hot(positive_indices, depth=N), axis=-2), tf.bool
    )
    # setting all selected samples to zeros
    values = tf.where(selected_indicators, zeros, values)
    # setting all excessive positive matches to zeros as well
    values = tf.where(positive_matches, zeros, values)
    num_neg_samples = num_samples - num_pos_samples
    _, negative_indices = tf.math.top_k(values, k=num_neg_samples)
    selected_indices = tf.concat([positive_indices, negative_indices], axis=-1)
    selected_indicators = tf.reduce_sum(
        tf.one_hot(selected_indices, depth=N), axis=-2
    )
    selected_indicators = tf.minimum(
        selected_indicators, tf.ones_like(selected_indicators)
    )
    selected_indicators = tf.where(
        valid_matches, selected_indicators, tf.zeros_like(selected_indicators)
    )
    return selected_indicators
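The scoring trick above — jitter positives around 1.0 and negatives around 0.5, take a top-k for positives, zero them out, then take a second top-k for negatives, and finally keep only valid indices — can be sketched without TensorFlow. This is a simplified, unbatched illustration of the same idea, not the library function:

```python
import random


def balanced_sample_sketch(positive, negative, num_samples, positive_fraction):
    """Unbatched sketch: positives score ~1.0, negatives ~0.5, everything
    else 0, so sorting by score picks positives first, breaking ties at
    random via the +/-0.2 jitter."""
    n = len(positive)
    if n < num_samples:
        raise ValueError(f"passed in {n} has less element than {num_samples}")
    scores = [0.0] * n
    for i in range(n):
        if positive[i]:
            scores[i] = 1.0 + random.uniform(-0.2, 0.2)
        elif negative[i]:
            scores[i] = 0.5 + random.uniform(-0.2, 0.2)
    num_pos = int(num_samples * positive_fraction)
    # first top-k: highest scores (may include negatives if positives are few)
    order = sorted(range(n), key=lambda i: -scores[i])
    pos_picked = order[:num_pos]
    # zero out everything already picked, plus all remaining positives
    for i in pos_picked:
        scores[i] = 0.0
    for i in range(n):
        if positive[i]:
            scores[i] = 0.0
    # second top-k: negatives only
    order = sorted(range(n), key=lambda i: -scores[i])
    neg_picked = order[: num_samples - num_pos]
    picked = set(pos_picked) | set(neg_picked)
    # invalid (neither positive nor negative) indices never count as selected
    return [
        1 if i in picked and (positive[i] or negative[i]) else 0
        for i in range(n)
    ]


# Mirrors test_balanced_sampling: one positive among nine negatives.
pos = [True] + [False] * 9
neg = [False] + [True] * 9
ind = balanced_sample_sketch(pos, neg, 5, 0.2)
assert ind[0] == 1 and sum(ind) == 5
```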
Keras/keras-cv/keras_cv/ops/sampling_test.py
# Copyright 2022 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf

from keras_cv.ops.sampling import balanced_sample


class BalancedSamplingTest(tf.test.TestCase):
    def test_balanced_sampling(self):
        positive_matches = tf.constant(
            [True, False, False, False, False, False, False, False, False, False]
        )
        negative_matches = tf.constant(
            [False, True, True, True, True, True, True, True, True, True]
        )
        num_samples = 5
        positive_fraction = 0.2
        res = balanced_sample(
            positive_matches, negative_matches, num_samples, positive_fraction
        )
        # The 1st element must be selected, given it's the only one.
        self.assertAllClose(res[0], 1)

    def test_balanced_batched_sampling(self):
        positive_matches = tf.constant(
            [
                [True, False, False, False, False, False, False, False, False, False],
                [False, False, False, False, False, False, True, False, False, False],
            ]
        )
        negative_matches = tf.constant(
            [
                [False, True, True, True, True, True, True, True, True, True],
                [True, True, True, True, True, True, False, True, True, True],
            ]
        )
        num_samples = 5
        positive_fraction = 0.2
        res = balanced_sample(
            positive_matches, negative_matches, num_samples, positive_fraction
        )
        # the 1st element from the 1st batch must be selected, given it's
        # the only one
        self.assertAllClose(res[0][0], 1)
        # the 7th element from the 2nd batch must be selected, given it's
        # the only one
        self.assertAllClose(res[1][6], 1)

    def test_balanced_sampling_over_positive_fraction(self):
        positive_matches = tf.constant(
            [True, False, False, False, False, False, False, False, False, False]
        )
        negative_matches = tf.constant(
            [False, True, True, True, True, True, True, True, True, True]
        )
        num_samples = 5
        positive_fraction = 0.4
        res = balanced_sample(
            positive_matches, negative_matches, num_samples, positive_fraction
        )
        # only 1 positive sample exists, thus it is chosen
        self.assertAllClose(res[0], 1)

    def test_balanced_sampling_under_positive_fraction(self):
        positive_matches = tf.constant(
            [True, False, False, False, False, False, False, False, False, False]
        )
        negative_matches = tf.constant(
            [False, True, True, True, True, True, True, True, True, True]
        )
        num_samples = 5
        positive_fraction = 0.1
        res = balanced_sample(
            positive_matches, negative_matches, num_samples, positive_fraction
        )
        # no positive is chosen
        self.assertAllClose(res[0], 0)
        self.assertAllClose(tf.reduce_sum(res), 5)

    def test_balanced_sampling_over_num_samples(self):
        positive_matches = tf.constant(
            [True, False, False, False, False, False, False, False, False, False]
        )
        negative_matches = tf.constant(
            [False, True, True, True, True, True, True, True, True, True]
        )
        # users want to get 20 samples, but only 10 are available
        num_samples = 20
        positive_fraction = 0.1
        with self.assertRaisesRegex(ValueError, "has less element"):
            _ = balanced_sample(
                positive_matches, negative_matches, num_samples, positive_fraction
            )

    def test_balanced_sampling_no_positive(self):
        positive_matches = tf.constant(
            [False, False, False, False, False, False, False, False, False, False]
        )
        # the rest are neither positive nor negative, but ignored matches
        negative_matches = tf.constant(
            [False, False, True, False, False, True, False, False, True, False]
        )
        num_samples = 5
        positive_fraction = 0.5
        res = balanced_sample(
            positive_matches, negative_matches, num_samples, positive_fraction
        )
        # given only 3 negative and 0 positive, select all of them
        self.assertAllClose(res, [0, 0, 1, 0, 0, 1, 0, 0, 1, 0])

    def test_balanced_sampling_no_negative(self):
        positive_matches = tf.constant(
            [True, True, False, False, False, False, False, False, False, False]
        )
        # 2-9 indices are neither positive nor negative, they're ignored
        # matches
        negative_matches = tf.constant([False] * 10)
        num_samples = 5
        positive_fraction = 0.5
        res = balanced_sample(
            positive_matches, negative_matches, num_samples, positive_fraction
        )
        # given only 2 positive and 0 negative, select all of them.
        self.assertAllClose(res, [1, 1, 0, 0, 0, 0, 0, 0, 0, 0])

    def test_balanced_sampling_many_samples(self):
        positive_matches = tf.random.uniform(
            [2, 1000], minval=0, maxval=1, dtype=tf.float32
        )
        positive_matches = positive_matches > 0.98
        negative_matches = tf.logical_not(positive_matches)
        num_samples = 256
        positive_fraction = 0.25
        _ = balanced_sample(
            positive_matches, negative_matches, num_samples, positive_fraction
        )
Keras/keras-cv/keras_cv/ops/target_gather.py
# Copyright 2022 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional

import tensorflow as tf


def _target_gather(
    targets: tf.Tensor,
    indices: tf.Tensor,
    mask: Optional[tf.Tensor] = None,
    mask_val: Optional[float] = 0.0,
):
    """A utility function wrapping tf.gather, which deals with:
        1) both batched and unbatched `targets`
        2) when unbatched `targets` have empty rows, the result will be
           filled with `mask_val`
        3) target masking.

    Args:
        targets: [N, ...] or [batch_size, N, ...] Tensor representing
            targets such as boxes, keypoints, etc.
        indices: [M] or [batch_size, M] int32 Tensor representing indices
            within `targets` to gather.
        mask: optional [M, ...] or [batch_size, M, ...] boolean Tensor
            representing the masking for each target. `True` means the
            corresponding entity should be masked to `mask_val`, `False`
            means the corresponding entity should be the target value.
        mask_val: optional float representing the masking value if `mask`
            is True on the entity.

    Returns:
        targets: [M, ...] or [batch_size, M, ...] Tensor representing
            selected targets.

    Raises:
        ValueError: If `targets` has rank larger than 3.
    """
    targets_shape = targets.get_shape().as_list()
    if len(targets_shape) > 3:
        raise ValueError(
            "`target_gather` does not support `targets` with rank "
            "larger than 3, got {}".format(len(targets.shape))
        )

    def _gather_unbatched(labels, match_indices, mask, mask_val):
        """Gather based on unbatched labels and boxes."""
        num_gt_boxes = tf.shape(labels)[0]

        def _assign_when_rows_empty():
            if len(labels.shape) > 1:
                mask_shape = [match_indices.shape[0], labels.shape[-1]]
            else:
                mask_shape = [match_indices.shape[0]]
            return tf.cast(mask_val, labels.dtype) * tf.ones(
                mask_shape, dtype=labels.dtype
            )

        def _assign_when_rows_not_empty():
            targets = tf.gather(labels, match_indices)
            if mask is None:
                return targets
            else:
                masked_targets = tf.cast(mask_val, labels.dtype) * tf.ones_like(
                    mask, dtype=labels.dtype
                )
                return tf.where(mask, masked_targets, targets)

        return tf.cond(
            tf.greater(num_gt_boxes, 0),
            _assign_when_rows_not_empty,
            _assign_when_rows_empty,
        )

    def _gather_batched(labels, match_indices, mask, mask_val):
        """Gather based on batched labels."""
        batch_size = labels.shape[0]
        if batch_size == 1:
            if mask is not None:
                result = _gather_unbatched(
                    tf.squeeze(labels, axis=0),
                    tf.squeeze(match_indices, axis=0),
                    tf.squeeze(mask, axis=0),
                    mask_val,
                )
            else:
                result = _gather_unbatched(
                    tf.squeeze(labels, axis=0),
                    tf.squeeze(match_indices, axis=0),
                    None,
                    mask_val,
                )
            return tf.expand_dims(result, axis=0)
        else:
            indices_shape = tf.shape(match_indices)
            indices_dtype = match_indices.dtype
            batch_indices = tf.expand_dims(
                tf.range(indices_shape[0], dtype=indices_dtype), axis=-1
            ) * tf.ones([1, indices_shape[-1]], dtype=indices_dtype)
            gather_nd_indices = tf.stack(
                [batch_indices, match_indices], axis=-1
            )
            targets = tf.gather_nd(labels, gather_nd_indices)
            if mask is None:
                return targets
            else:
                masked_targets = tf.cast(mask_val, labels.dtype) * tf.ones_like(
                    mask, dtype=labels.dtype
                )
                return tf.where(mask, masked_targets, targets)

    if len(targets_shape) <= 2:
        return _gather_unbatched(targets, indices, mask, mask_val)
    elif len(targets_shape) == 3:
        return _gather_batched(targets, indices, mask, mask_val)
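The batched branch above pairs each row's batch index with its match indices before calling `tf.gather_nd`. The same indexing can be sketched with plain Python lists (an illustration of the semantics, not the TF implementation, and it covers only the batched, non-empty case):

```python
def target_gather_sketch(targets, indices, mask=None, mask_val=0):
    """List-based sketch of batched _target_gather on rank-2 targets:
    out[b][m] = targets[b][indices[b][m]], with masked entries replaced
    by mask_val."""
    out = []
    for b, (row, idx) in enumerate(zip(targets, indices)):
        gathered = [row[i] for i in idx]
        if mask is not None:
            # True in the mask means: replace the gathered value by mask_val
            gathered = [mask_val if m else v for v, m in zip(gathered, mask[b])]
        out.append(gathered)
    return out


# Mirrors test_target_gather_classes_multi_batch below.
classes = [[1, 2, 3, 4], [5, 6, 7, 8]]
assert target_gather_sketch(classes, [[0, 2], [1, 3]]) == [[1, 3], [6, 8]]
# Mirrors test_target_gather_classes_batched_with_mask_val.
assert target_gather_sketch(
    [[1, 2, 3, 4]], [[0, 2]], mask=[[False, True]], mask_val=-1
) == [[1, -1]]
```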
Keras/keras-cv/keras_cv/ops/target_gather_test.py
# Copyright 2022 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf

from keras_cv.ops.target_gather import _target_gather


class TargetGatherTest(tf.test.TestCase):
    def test_target_gather_boxes_batched(self):
        target_boxes = tf.constant(
            [[0, 0, 5, 5], [0, 5, 5, 10], [5, 0, 10, 5], [5, 5, 10, 10]]
        )
        target_boxes = target_boxes[tf.newaxis, ...]
        indices = tf.constant([[0, 2]], dtype=tf.int32)
        expected_boxes = tf.constant([[0, 0, 5, 5], [5, 0, 10, 5]])
        expected_boxes = expected_boxes[tf.newaxis, ...]
        res = _target_gather(target_boxes, indices)
        self.assertAllClose(expected_boxes, res)

    def test_target_gather_boxes_unbatched(self):
        target_boxes = tf.constant(
            [[0, 0, 5, 5], [0, 5, 5, 10], [5, 0, 10, 5], [5, 5, 10, 10]]
        )
        indices = tf.constant([0, 2], dtype=tf.int32)
        expected_boxes = tf.constant([[0, 0, 5, 5], [5, 0, 10, 5]])
        res = _target_gather(target_boxes, indices)
        self.assertAllClose(expected_boxes, res)

    def test_target_gather_classes_batched(self):
        target_classes = tf.constant([[1, 2, 3, 4]])
        target_classes = target_classes[..., tf.newaxis]
        indices = tf.constant([[0, 2]], dtype=tf.int32)
        expected_classes = tf.constant([[1, 3]])
        expected_classes = expected_classes[..., tf.newaxis]
        res = _target_gather(target_classes, indices)
        self.assertAllClose(expected_classes, res)

    def test_target_gather_classes_unbatched(self):
        target_classes = tf.constant([1, 2, 3, 4])
        target_classes = target_classes[..., tf.newaxis]
        indices = tf.constant([0, 2], dtype=tf.int32)
        expected_classes = tf.constant([1, 3])
        expected_classes = expected_classes[..., tf.newaxis]
        res = _target_gather(target_classes, indices)
        self.assertAllClose(expected_classes, res)

    def test_target_gather_classes_batched_with_mask(self):
        target_classes = tf.constant([[1, 2, 3, 4]])
        target_classes = target_classes[..., tf.newaxis]
        indices = tf.constant([[0, 2]], dtype=tf.int32)
        masks = tf.constant([[False, True]])
        masks = masks[..., tf.newaxis]
        # the second element is masked
        expected_classes = tf.constant([[1, 0]])
        expected_classes = expected_classes[..., tf.newaxis]
        res = _target_gather(target_classes, indices, masks)
        self.assertAllClose(expected_classes, res)

    def test_target_gather_classes_batched_with_mask_val(self):
        target_classes = tf.constant([[1, 2, 3, 4]])
        target_classes = target_classes[..., tf.newaxis]
        indices = tf.constant([[0, 2]], dtype=tf.int32)
        masks = tf.constant([[False, True]])
        masks = masks[..., tf.newaxis]
        # the second element is masked
        expected_classes = tf.constant([[1, -1]])
        expected_classes = expected_classes[..., tf.newaxis]
        res = _target_gather(target_classes, indices, masks, -1)
        self.assertAllClose(expected_classes, res)

    def test_target_gather_classes_unbatched_with_mask(self):
        target_classes = tf.constant([1, 2, 3, 4])
        target_classes = target_classes[..., tf.newaxis]
        indices = tf.constant([0, 2], dtype=tf.int32)
        masks = tf.constant([False, True])
        masks = masks[..., tf.newaxis]
        expected_classes = tf.constant([1, 0])
        expected_classes = expected_classes[..., tf.newaxis]
        res = _target_gather(target_classes, indices, masks)
        self.assertAllClose(expected_classes, res)

    def test_target_gather_with_empty_targets(self):
        target_classes = tf.constant([])
        target_classes = target_classes[..., tf.newaxis]
        indices = tf.constant([0, 2], dtype=tf.int32)
        # return all 0s since input is empty
        expected_classes = tf.constant([0, 0])
        expected_classes = expected_classes[..., tf.newaxis]
        res = _target_gather(target_classes, indices)
        self.assertAllClose(expected_classes, res)

    def test_target_gather_classes_multi_batch(self):
        target_classes = tf.constant([[1, 2, 3, 4], [5, 6, 7, 8]])
        target_classes = target_classes[..., tf.newaxis]
        indices = tf.constant([[0, 2], [1, 3]], dtype=tf.int32)
        expected_classes = tf.constant([[1, 3], [6, 8]])
        expected_classes = expected_classes[..., tf.newaxis]
        res = _target_gather(target_classes, indices)
        self.assertAllClose(expected_classes, res)

    def test_target_gather_invalid_rank(self):
        targets = tf.random.normal([32, 2, 2, 2])
        indices = tf.constant([0, 1], dtype=tf.int32)
        with self.assertRaisesRegex(ValueError, "larger than 3"):
            _ = _target_gather(targets, indices)
Keras/keras-cv/keras_cv/training/__init__.py
# Copyright 2022 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from keras_cv.training.contrastive.contrastive_trainer import ContrastiveTrainer
from keras_cv.training.contrastive.simclr_trainer import SimCLRAugmenter
from keras_cv.training.contrastive.simclr_trainer import SimCLRTrainer
Keras/keras-cv/keras_cv/training/contrastive/__init__.py
# Copyright 2022 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Keras/keras-cv/keras_cv/training/contrastive/contrastive_trainer.py
# Copyright 2022 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
tensorflow
as
tf
from
tensorflow
import
keras
from
keras_cv.utils.train
import
convert_inputs_to_tf_dataset
class
ContrastiveTrainer
(
keras
.
Model
):
"""Creates a self-supervised contrastive trainer for a model.
Args:
encoder: a `keras.Model` to be pre-trained. In most cases, this encoder
should not include a top dense layer.
augmenter: a preprocessing layer to randomly augment input images for contrastive learning,
or a tuple of two separate augmenters for the two sides of the contrastive pipeline.
projector: a projection model for contrastive training, or a tuple of two separate
projectors for the two sides of the contrastive pipeline. This shrinks
the feature map produced by the encoder, and is usually a 1 or
2-layer dense MLP.
probe: An optional Keras layer or model which will be trained against
class labels at train-time using the encoder output as input.
Note that this should be specified iff training with labeled images.
This predicts class labels based on the feature map produced by the
encoder and is usually a 1 or 2-layer dense MLP.
Returns:
A `keras.Model` instance.
Usage:
```python
encoder = keras_cv.models.DenseNet121(include_rescaling=True, include_top=False, pooling="avg")
augmenter = keras_cv.layers.preprocessing.RandomFlip()
projector = keras.layers.Dense(64)
probe = keras_cv.training.ContrastiveTrainer.linear_probe(classes=10)
trainer = keras_cv.training.ContrastiveTrainer(
encoder=encoder,
augmenter=augmenter,
projector=projector,
probe=probe
)
trainer.compile(
encoder_optimizer=keras.optimizers.Adam(),
encoder_loss=keras_cv.losses.SimCLRLoss(temperature=0.5),
probe_optimizer=keras.optimizers.Adam(),
probe_loss=keras.losses.CategoricalCrossentropy(from_logits=True),
probe_metrics=[keras.metrics.CategoricalAccuracy(name="probe_accuracy")]
)
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
y_train = keras.utils.to_categorical(y_train, 10)
trainer.fit(x_train, y_train)
```
"""
def
__init__
(
self
,
encoder
,
augmenter
,
projector
,
probe
=
None
,
):
super
().
__init__
()
if
encoder
.
output
.
shape
.
rank
!=
2
:
raise
ValueError
(
f
"`encoder` must have a flattened output. Expected rank(encoder.output.shape)=2, got encoder.output.shape=
{
encoder
.
output
.
shape
}
"
)
if
type
(
augmenter
)
is
tuple
and
len
(
augmenter
)
!=
2
:
raise
ValueError
(
"`augmenter` must be either a single augmenter or a tuple of exactly 2 augmenters."
)
if
type
(
projector
)
is
tuple
and
len
(
projector
)
!=
2
:
raise
ValueError
(
"`projector` must be either a single augmenter or a tuple of exactly 2 augmenters."
)
self
.
augmenters
=
(
augmenter
if
type
(
augmenter
)
is
tuple
else
(
augmenter
,
augmenter
)
)
self
.
encoder
=
encoder
self
.
projectors
=
(
projector
if
type
(
projector
)
is
tuple
else
(
projector
,
projector
)
)
self
.
probe
=
probe
self
.
loss_metric
=
keras
.
metrics
.
Mean
(
name
=
"loss"
)
if
probe
is
not
None
:
self
.
probe_loss_metric
=
keras
.
metrics
.
Mean
(
name
=
"probe_loss"
)
self
.
probe_metrics
=
[]
def
compile
(
self
,
encoder_loss
,
encoder_optimizer
,
encoder_metrics
=
None
,
probe_optimizer
=
None
,
probe_loss
=
None
,
probe_metrics
=
None
,
**
kwargs
,
):
super
().
compile
(
loss
=
encoder_loss
,
optimizer
=
encoder_optimizer
,
metrics
=
encoder_metrics
,
**
kwargs
,
)
if
self
.
probe
and
not
probe_optimizer
:
raise
ValueError
(
"`probe_optimizer` must be specified when a probe is included."
)
if
self
.
probe
and
not
probe_loss
:
raise
ValueError
(
"`probe_loss` must be specified when a probe is included."
)
if
"loss"
in
kwargs
:
raise
ValueError
(
"`loss` parameter in ContrastiveTrainer.compile is ambiguous. Please specify `encoder_loss` or `probe_loss`."
)
if
"optimizer"
in
kwargs
:
raise
ValueError
(
"`optimizer` parameter in ContrastiveTrainer.compile is ambiguous. Please specify `encoder_optimizer` or `probe_optimizer`."
)
if
"metrics"
in
kwargs
:
raise
ValueError
(
"`metrics` parameter in ContrastiveTrainer.compile is ambiguous. Please specify `encoder_metrics` or `probe_metrics`."
)
if
self
.
probe
:
self
.
probe_loss
=
probe_loss
self
.
probe_optimizer
=
probe_optimizer
self
.
probe_metrics
=
probe_metrics
or
[]
    @property
    def metrics(self):
        metrics = [
            self.loss_metric,
        ]
        if self.probe:
            metrics += [self.probe_loss_metric]
            metrics += self.probe_metrics
        return super().metrics + metrics

    def fit(
        self,
        x=None,
        y=None,
        sample_weight=None,
        batch_size=None,
        **kwargs,
    ):
        dataset = convert_inputs_to_tf_dataset(
            x=x, y=y, sample_weight=sample_weight, batch_size=batch_size
        )

        dataset = dataset.map(
            self.run_augmenters, num_parallel_calls=tf.data.AUTOTUNE
        )
        dataset = dataset.prefetch(tf.data.AUTOTUNE)

        return super().fit(x=dataset, **kwargs)

    def run_augmenters(self, x, y=None):
        inputs = {"images": x}
        if y is not None:
            inputs["labels"] = y

        inputs["augmented_images_0"] = self.augmenters[0](x, training=True)
        inputs["augmented_images_1"] = self.augmenters[1](x, training=True)

        return inputs

    def train_step(self, data):
        images = data["images"]
        labels = data["labels"] if "labels" in data else None
        augmented_images_0 = data["augmented_images_0"]
        augmented_images_1 = data["augmented_images_1"]

        with tf.GradientTape() as tape:
            features_0 = self.encoder(augmented_images_0, training=True)
            features_1 = self.encoder(augmented_images_1, training=True)

            projections_0 = self.projectors[0](features_0, training=True)
            projections_1 = self.projectors[1](features_1, training=True)

            loss = self.compiled_loss(
                projections_0,
                projections_1,
                regularization_losses=self.encoder.losses,
            )

        gradients = tape.gradient(
            loss,
            self.encoder.trainable_weights
            + self.projectors[0].trainable_weights
            + self.projectors[1].trainable_weights,
        )

        self.optimizer.apply_gradients(
            zip(
                gradients,
                self.encoder.trainable_weights
                + self.projectors[0].trainable_weights
                + self.projectors[1].trainable_weights,
            )
        )
        self.loss_metric.update_state(loss)

        if self.probe:
            if labels is None:
                raise ValueError(
                    "Targets must be provided when a probe is specified"
                )
            with tf.GradientTape() as tape:
                features = tf.stop_gradient(self.encoder(images, training=False))
                class_logits = self.probe(features, training=True)
                probe_loss = self.probe_loss(labels, class_logits)
            gradients = tape.gradient(probe_loss, self.probe.trainable_weights)
            self.probe_optimizer.apply_gradients(
                zip(gradients, self.probe.trainable_weights)
            )
            self.probe_loss_metric.update_state(probe_loss)
            for metric in self.probe_metrics:
                metric.update_state(labels, class_logits)

        return {metric.name: metric.result() for metric in self.metrics}

    def call(self, inputs):
        raise NotImplementedError(
            "ContrastiveTrainer.call() is not implemented - "
            "please call your model directly."
        )

    @staticmethod
    def linear_probe(classes, **kwargs):
        return keras.Sequential(keras.layers.Dense(classes), **kwargs)
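`ContrastiveTrainer.fit` maps `run_augmenters` over the input dataset so that every batch carries two independently augmented views alongside the raw images, and `train_step` then reads those views back out by key. A minimal pure-Python sketch of that dictionary contract (the flip functions are illustrative stand-ins for the Keras preprocessing layers, not the real augmenters):

```python
# Sketch of the per-batch dictionary built by ContrastiveTrainer.run_augmenters.
# The two "augmenters" here are plain functions on nested lists.

def run_augmenters(x, augmenters, y=None):
    """Mirror of run_augmenters: raw images plus two augmented views."""
    inputs = {"images": x}
    if y is not None:
        inputs["labels"] = y
    # Two independently augmented views of the same batch.
    inputs["augmented_images_0"] = augmenters[0](x)
    inputs["augmented_images_1"] = augmenters[1](x)
    return inputs

# Dummy "augmentations": horizontal / vertical flip of each 2-D image.
flip_h = lambda batch: [[row[::-1] for row in img] for img in batch]
flip_v = lambda batch: [img[::-1] for img in batch]

batch = [[[1, 2], [3, 4]]]  # one 2x2 single-channel "image"
out = run_augmenters(batch, (flip_h, flip_v), y=[0])
```

`train_step` consumes exactly these keys: the two views feed the encoder and projectors, while the raw `"images"` (and `"labels"`, when present) feed the probe.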
Keras/keras-cv/keras_cv/training/contrastive/contrastive_trainer_test.py
# Copyright 2022 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import metrics
from tensorflow.keras import optimizers

from keras_cv.layers import preprocessing
from keras_cv.losses import SimCLRLoss
from keras_cv.models import DenseNet121
from keras_cv.training import ContrastiveTrainer


class ContrastiveTrainerTest(tf.test.TestCase):
    def test_probe_requires_probe_optimizer(self):
        trainer = ContrastiveTrainer(
            encoder=self.build_encoder(),
            augmenter=self.build_augmenter(),
            projector=self.build_projector(),
            probe=self.build_probe(),
        )
        with self.assertRaises(ValueError):
            trainer.compile(
                encoder_optimizer=optimizers.Adam(),
                encoder_loss=SimCLRLoss(temperature=0.5),
            )

    def test_targets_required_if_probing(self):
        trainer_with_probing = ContrastiveTrainer(
            encoder=self.build_encoder(),
            augmenter=self.build_augmenter(),
            projector=self.build_projector(),
            probe=self.build_probe(),
        )
        trainer_without_probing = ContrastiveTrainer(
            encoder=self.build_encoder(),
            augmenter=self.build_augmenter(),
            projector=self.build_projector(),
            probe=None,
        )

        images = tf.random.uniform((1, 50, 50, 3))

        trainer_with_probing.compile(
            encoder_optimizer=optimizers.Adam(),
            encoder_loss=SimCLRLoss(temperature=0.5),
            probe_optimizer=optimizers.Adam(),
            probe_loss=keras.losses.CategoricalCrossentropy(from_logits=True),
        )
        trainer_without_probing.compile(
            encoder_optimizer=optimizers.Adam(),
            encoder_loss=SimCLRLoss(temperature=0.5),
        )

        with self.assertRaises(ValueError):
            trainer_with_probing.fit(images)

    def test_train_with_probing(self):
        trainer_with_probing = ContrastiveTrainer(
            encoder=self.build_encoder(),
            augmenter=self.build_augmenter(),
            projector=self.build_projector(),
            probe=self.build_probe(classes=20),
        )

        images = tf.random.uniform((1, 50, 50, 3))
        targets = tf.ones((1, 20))

        trainer_with_probing.compile(
            encoder_optimizer=optimizers.Adam(),
            encoder_loss=SimCLRLoss(temperature=0.5),
            probe_metrics=[
                metrics.TopKCategoricalAccuracy(3, "top3_probe_accuracy")
            ],
            probe_optimizer=optimizers.Adam(),
            probe_loss=keras.losses.CategoricalCrossentropy(from_logits=True),
        )

        trainer_with_probing.fit(images, targets)

    def test_train_without_probing(self):
        trainer_without_probing = ContrastiveTrainer(
            encoder=self.build_encoder(),
            augmenter=self.build_augmenter(),
            projector=self.build_projector(),
            probe=None,
        )

        images = tf.random.uniform((1, 50, 50, 3))
        targets = tf.ones((1, 20))

        trainer_without_probing.compile(
            encoder_optimizer=optimizers.Adam(),
            encoder_loss=SimCLRLoss(temperature=0.5),
        )

        trainer_without_probing.fit(images)
        trainer_without_probing.fit(images, targets)

    def test_inference_not_supported(self):
        trainer = ContrastiveTrainer(
            encoder=self.build_encoder(),
            augmenter=self.build_augmenter(),
            projector=self.build_projector(),
            probe=None,
        )
        trainer.compile(
            encoder_optimizer=optimizers.Adam(),
            encoder_loss=SimCLRLoss(temperature=0.5),
        )

        with self.assertRaises(NotImplementedError):
            trainer(tf.ones((1, 50, 50, 3)))

    def test_encoder_must_have_flat_output(self):
        with self.assertRaises(ValueError):
            _ = ContrastiveTrainer(
                # A DenseNet without pooling does not have a flat output
                encoder=DenseNet121(include_rescaling=False, include_top=False),
                augmenter=self.build_augmenter(),
                projector=self.build_projector(),
                probe=None,
            )

    def test_with_multiple_augmenters_and_projectors(self):
        augmenter0 = preprocessing.RandomFlip("horizontal")
        augmenter1 = preprocessing.RandomFlip("vertical")

        projector0 = layers.Dense(64, name="projector0")
        projector1 = keras.Sequential(
            [projector0, layers.ReLU(), layers.Dense(64, name="projector1")]
        )

        trainer_without_probing = ContrastiveTrainer(
            encoder=self.build_encoder(),
            augmenter=(augmenter0, augmenter1),
            projector=(projector0, projector1),
            probe=None,
        )

        images = tf.random.uniform((1, 50, 50, 3))

        trainer_without_probing.compile(
            encoder_optimizer=optimizers.Adam(),
            encoder_loss=SimCLRLoss(temperature=0.5),
        )

        trainer_without_probing.fit(images)

    def build_augmenter(self):
        return preprocessing.RandomFlip("horizontal")

    def build_encoder(self):
        return DenseNet121(include_rescaling=False, include_top=False, pooling="avg")

    def build_projector(self):
        return layers.Dense(128)

    def build_probe(self, classes=20):
        return layers.Dense(classes)
Keras/keras-cv/keras_cv/training/contrastive/simclr_trainer.py
from tensorflow import keras
from tensorflow.keras import layers

from keras_cv.layers import preprocessing
from keras_cv.training import ContrastiveTrainer


class SimCLRTrainer(ContrastiveTrainer):
    """Creates a SimCLRTrainer.

    References:
        - [SimCLR paper](https://arxiv.org/pdf/2002.05709)

    Args:
        encoder: a `keras.Model` to be pre-trained. In most cases, this encoder
            should not include a top dense layer.
        augmenter: a SimCLRAugmenter layer to randomly augment input
            images for contrastive learning
        projection_width: the width of the two-layer dense model used for
            projection in the SimCLR paper
    """

    def __init__(self, encoder, augmenter, projection_width=128, **kwargs):
        super().__init__(
            encoder=encoder,
            augmenter=augmenter,
            projector=keras.Sequential(
                [
                    layers.Dense(projection_width, activation="relu"),
                    layers.Dense(projection_width),
                    layers.BatchNormalization(),
                ],
                name="projector",
            ),
            **kwargs,
        )
class SimCLRAugmenter(preprocessing.Augmenter):
    def __init__(
        self,
        value_range,
        height=128,
        width=128,
        crop_area_factor=(0.08, 1.0),
        aspect_ratio_factor=(3 / 4, 4 / 3),
        grayscale_rate=0.2,
        color_jitter_rate=0.8,
        brightness_factor=0.2,
        contrast_factor=0.8,
        saturation_factor=(0.3, 0.7),
        hue_factor=0.2,
        **kwargs,
    ):
        super().__init__(
            [
                preprocessing.RandomFlip("horizontal"),
                preprocessing.RandomCropAndResize(
                    target_size=(height, width),
                    crop_area_factor=crop_area_factor,
                    aspect_ratio_factor=aspect_ratio_factor,
                ),
                preprocessing.MaybeApply(
                    preprocessing.Grayscale(output_channels=3),
                    rate=grayscale_rate,
                ),
                preprocessing.MaybeApply(
                    preprocessing.RandomColorJitter(
                        value_range=value_range,
                        brightness_factor=brightness_factor,
                        contrast_factor=contrast_factor,
                        saturation_factor=saturation_factor,
                        hue_factor=hue_factor,
                    ),
                    rate=color_jitter_rate,
                ),
            ],
            **kwargs,
        )
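SimCLRTrainer wires a non-linear projection head (Dense + ReLU, Dense, BatchNormalization) behind the encoder. A NumPy sketch of that head's shape transformation, with random weights standing in for the learned ones and a hypothetical `feature_dim=2048` (e.g. pooled ResNet50V2 features):

```python
import numpy as np

# Sketch of the SimCLR projection head as plain NumPy ops:
# Dense(projection_width, relu) -> Dense(projection_width) -> BatchNorm.
# Weights are random placeholders; in the trainer they are learned.

rng = np.random.default_rng(0)
projection_width = 128
feature_dim = 2048  # illustrative encoder output width

w0 = rng.normal(size=(feature_dim, projection_width)) * 0.01
w1 = rng.normal(size=(projection_width, projection_width)) * 0.01

def project(features):
    h = np.maximum(features @ w0, 0.0)  # Dense + ReLU
    z = h @ w1                          # Dense (linear)
    # batch normalization without learned scale/offset
    return (z - z.mean(axis=0)) / (z.std(axis=0) + 1e-5)

features = rng.normal(size=(8, feature_dim))
projections = project(features)
```

The contrastive loss then compares these batch-normalized, fixed-width projections across the two augmented views rather than the raw encoder features.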
Keras/keras-cv/keras_cv/training/contrastive/simclr_trainer_test.py
import tensorflow as tf
from tensorflow.keras import optimizers

from keras_cv.losses import SimCLRLoss
from keras_cv.models import ResNet50V2
from keras_cv.training import SimCLRAugmenter
from keras_cv.training import SimCLRTrainer


class SimCLRTrainerTest(tf.test.TestCase):
    def test_train_without_probing(self):
        simclr_without_probing = SimCLRTrainer(
            self.build_encoder(),
            augmenter=SimCLRAugmenter(value_range=(0, 255)),
        )

        images = tf.random.uniform((10, 512, 512, 3))

        simclr_without_probing.compile(
            encoder_optimizer=optimizers.Adam(),
            encoder_loss=SimCLRLoss(temperature=0.5),
        )

        simclr_without_probing.fit(images)

    def build_encoder(self):
        return ResNet50V2(include_rescaling=False, include_top=False, pooling="avg")
Keras/keras-cv/keras_cv/utils/__init__.py
from keras_cv.utils.fill_utils import fill_rectangle
from keras_cv.utils.preprocessing import blend
from keras_cv.utils.preprocessing import parse_factor
from keras_cv.utils.preprocessing import transform
from keras_cv.utils.preprocessing import transform_value_range
from keras_cv.utils.train import convert_inputs_to_tf_dataset
from keras_cv.utils.train import scale_loss_for_distribution
Keras/keras-cv/keras_cv/utils/conv_utils.py
def normalize_tuple(value, n, name, allow_zero=False):
    """Transforms non-negative/positive integer/integers into an integer tuple.

    Args:
        value: The value to validate and convert. Could be an int, or any
            iterable of ints.
        n: The size of the tuple to be returned.
        name: The name of the argument being validated, e.g. "strides" or
            "kernel_size". This is only used to format error messages.
        allow_zero: Defaults to False. A ValueError will be raised if zero is
            received and this param is False.

    Returns:
        A tuple of n integers.

    Raises:
        ValueError: If something other than an int/long or an iterable thereof,
            or a negative value, is passed.
    """
    error_msg = (
        f"The `{name}` argument must be a tuple of {n} "
        f"integers. Received: {value}"
    )

    if isinstance(value, int):
        value_tuple = (value,) * n
    else:
        try:
            value_tuple = tuple(value)
        except TypeError:
            raise ValueError(error_msg)
        if len(value_tuple) != n:
            raise ValueError(error_msg)
        for single_value in value_tuple:
            try:
                int(single_value)
            except (ValueError, TypeError):
                error_msg += (
                    f", including element {single_value} of "
                    f"type {type(single_value)}"
                )
                raise ValueError(error_msg)

    if allow_zero:
        unqualified_values = {v for v in value_tuple if v < 0}
        req_msg = ">= 0"
    else:
        unqualified_values = {v for v in value_tuple if v <= 0}
        req_msg = "> 0"

    if unqualified_values:
        error_msg += (
            f", including {unqualified_values}"
            f" that does not satisfy the requirement `{req_msg}`."
        )
        raise ValueError(error_msg)

    return value_tuple
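To make `normalize_tuple`'s contract concrete, here is a self-contained mini version covering its happy paths and range check (error-message formatting and the per-element int-convertibility check are simplified; the argument names are illustrative):

```python
# Compact mirror of normalize_tuple: an int is broadcast to an n-tuple,
# an iterable is validated for length and for the allowed value range.

def normalize_tuple(value, n, name, allow_zero=False):
    if isinstance(value, int):
        value_tuple = (value,) * n
    else:
        value_tuple = tuple(value)
        if len(value_tuple) != n:
            raise ValueError(f"`{name}` must have {n} elements: {value}")
    lower = 0 if allow_zero else 1
    if any(v < lower for v in value_tuple):
        raise ValueError(f"`{name}` has out-of-range values: {value}")
    return value_tuple

print(normalize_tuple(3, 2, "kernel_size"))                      # (3, 3)
print(normalize_tuple([1, 2], 2, "strides"))                     # (1, 2)
print(normalize_tuple((0, 4), 2, "dilation", allow_zero=True))   # (0, 4)
```

This is the usual Keras-style convenience that lets layer arguments like `kernel_size=3` and `kernel_size=(3, 3)` mean the same thing.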
Keras/keras-cv/keras_cv/utils/fill_utils.py
import tensorflow as tf

from keras_cv import bounding_box


def _axis_mask(starts, ends, mask_len):
    # index range of axis
    batch_size = tf.shape(starts)[0]
    axis_indices = tf.range(mask_len, dtype=starts.dtype)
    axis_indices = tf.expand_dims(axis_indices, 0)
    axis_indices = tf.tile(axis_indices, [batch_size, 1])

    # mask of index bounds
    axis_mask = tf.greater_equal(axis_indices, starts) & tf.less(axis_indices, ends)
    return axis_mask


def corners_to_mask(bounding_boxes, mask_shape):
    """Converts bounding boxes in corners format to boolean masks.

    Args:
        bounding_boxes: tensor of rectangle coordinates with shape
            (batch_size, 4) in corners format (x0, y0, x1, y1).
        mask_shape: a shape tuple as (width, height) indicating the output
            width and height of masks.

    Returns:
        boolean masks with shape (batch_size, width, height) where True values
            indicate positions within bounding box coordinates.
    """
    mask_width, mask_height = mask_shape
    x0, y0, x1, y1 = tf.split(bounding_boxes, [1, 1, 1, 1], axis=-1)

    w_mask = _axis_mask(x0, x1, mask_width)
    h_mask = _axis_mask(y0, y1, mask_height)

    w_mask = tf.expand_dims(w_mask, axis=1)
    h_mask = tf.expand_dims(h_mask, axis=2)
    masks = tf.logical_and(w_mask, h_mask)

    return masks


def fill_rectangle(images, centers_x, centers_y, widths, heights, fill_values):
    """Fill rectangles with fill value into images.

    Args:
        images: Tensor of images to fill rectangles into.
        centers_x: Tensor of positions of the rectangle centers on the x-axis.
        centers_y: Tensor of positions of the rectangle centers on the y-axis.
        widths: Tensor of widths of the rectangles.
        heights: Tensor of heights of the rectangles.
        fill_values: Tensor with same shape as images to get rectangle fill from.

    Returns:
        images with filled rectangles.
    """
    images_shape = tf.shape(images)
    images_height = images_shape[1]
    images_width = images_shape[2]

    xywh = tf.stack([centers_x, centers_y, widths, heights], axis=1)
    xywh = tf.cast(xywh, tf.float32)
    corners = bounding_box.convert_format(xywh, source="center_xywh", target="xyxy")

    mask_shape = (images_width, images_height)
    is_rectangle = corners_to_mask(corners, mask_shape)
    is_rectangle = tf.expand_dims(is_rectangle, -1)

    images = tf.where(is_rectangle, fill_values, images)
    return images
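The `_axis_mask`/`corners_to_mask` pair above reduces the 2-D membership test to two 1-D interval tests: a pixel (x, y) is inside a box iff x0 <= x < x1 and y0 <= y < y1, and the 2-D mask is the outer "and" of those per-axis masks. A NumPy sketch of the same idea, broadcast over the batch:

```python
import numpy as np

# NumPy sketch of corners_to_mask: two 1-D interval masks combined
# by broadcasting into a (batch, height, width) boolean mask.

def corners_to_mask(boxes, mask_shape):
    mask_w, mask_h = mask_shape
    x0, y0, x1, y1 = np.split(np.asarray(boxes, dtype=np.float32), 4, axis=-1)
    xs = np.arange(mask_w)[None, :]
    ys = np.arange(mask_h)[None, :]
    w_mask = (xs >= x0) & (xs < x1)  # (batch, mask_w)
    h_mask = (ys >= y0) & (ys < y1)  # (batch, mask_h)
    # outer product of the two 1-D masks per batch element
    return h_mask[:, :, None] & w_mask[:, None, :]

mask = corners_to_mask([[1, 0, 4, 3]], mask_shape=(6, 6))
```

For the box (1, 0, 4, 3) on a 6x6 grid this marks columns 1..3 of rows 0..2, matching the half-open interval convention of `_axis_mask`.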
Keras/keras-cv/keras_cv/utils/fill_utils_test.py
import tensorflow as tf

from keras_cv.utils import fill_utils


class BoundingBoxToMaskTest(tf.test.TestCase):
    def _run_test(self, corners, expected):
        mask = fill_utils.corners_to_mask(corners, mask_shape=(6, 6))
        mask = tf.cast(mask, dtype=tf.int32)
        tf.assert_equal(mask, expected)

    def test_corners_whole(self):
        expected = tf.constant(
            [
                [0, 1, 1, 1, 0, 0],
                [0, 1, 1, 1, 0, 0],
                [0, 1, 1, 1, 0, 0],
                [0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0],
            ],
            dtype=tf.int32,
        )
        corners = tf.constant([[1, 0, 4, 3]], dtype=tf.float32)
        self._run_test(corners, expected)

    def test_corners_frac(self):
        expected = tf.constant(
            [
                [0, 0, 0, 0, 0, 0],
                [0, 0, 1, 1, 1, 0],
                [0, 0, 1, 1, 1, 0],
                [0, 0, 1, 1, 1, 0],
                [0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0],
            ]
        )
        corners = tf.constant([[1.5, 0.5, 4.5, 3.5]], dtype=tf.float32)
        self._run_test(corners, expected)

    def test_width_zero(self):
        expected = tf.constant(
            [
                [0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0],
            ]
        )
        corners = tf.constant([[0, 0, 0, 3]], dtype=tf.float32)
        self._run_test(corners, expected)

    def test_height_zero(self):
        expected = tf.constant(
            [
                [0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0],
            ]
        )
        corners = tf.constant([[1, 0, 4, 0]], dtype=tf.float32)
        self._run_test(corners, expected)

    def test_width_negative(self):
        expected = tf.constant(
            [
                [0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0],
            ]
        )
        corners = tf.constant([[1, 0, -2, 3]], dtype=tf.float32)
        self._run_test(corners, expected)

    def test_height_negative(self):
        expected = tf.constant(
            [
                [0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0],
            ]
        )
        corners = tf.constant([[1, 0, 4, -2]], dtype=tf.float32)
        self._run_test(corners, expected)

    def test_width_out_of_lower_bound(self):
        expected = tf.constant(
            [
                [1, 1, 0, 0, 0, 0],
                [1, 1, 0, 0, 0, 0],
                [1, 1, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0],
            ]
        )
        corners = tf.constant([[-2, -2, 2, 3]], dtype=tf.float32)
        self._run_test(corners, expected)

    def test_width_out_of_upper_bound(self):
        expected = tf.constant(
            [
                [0, 0, 0, 0, 1, 1],
                [0, 0, 0, 0, 1, 1],
                [0, 0, 0, 0, 1, 1],
                [0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0],
            ]
        )
        corners = tf.constant([[4, 0, 8, 3]], dtype=tf.float32)
        self._run_test(corners, expected)

    def test_height_out_of_lower_bound(self):
        expected = tf.constant(
            [
                [0, 1, 1, 1, 0, 0],
                [0, 1, 1, 1, 0, 0],
                [0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0],
            ]
        )
        corners = tf.constant([[1, -3, 4, 2]], dtype=tf.float32)
        self._run_test(corners, expected)

    def test_height_out_of_upper_bound(self):
        expected = tf.constant(
            [
                [0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0],
                [0, 1, 1, 1, 0, 0],
                [0, 1, 1, 1, 0, 0],
            ]
        )
        corners = tf.constant([[1, 4, 4, 9]], dtype=tf.float32)
        self._run_test(corners, expected)

    def test_start_out_of_upper_bound(self):
        expected = tf.constant(
            [
                [0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0],
            ]
        )
        corners = tf.constant([[8, 8, 10, 12]], dtype=tf.float32)
        self._run_test(corners, expected)


class FillRectangleTest(tf.test.TestCase):
    def _run_test(self, img_w, img_h, cent_x, cent_y, rec_w, rec_h, expected):
        batch_size = 1
        batch_shape = (batch_size, img_h, img_w, 1)
        images = tf.ones(batch_shape, dtype=tf.int32)

        centers_x = tf.fill([batch_size], cent_x)
        centers_y = tf.fill([batch_size], cent_y)
        width = tf.fill([batch_size], rec_w)
        height = tf.fill([batch_size], rec_h)
        fill = tf.zeros_like(images)

        filled_images = fill_utils.fill_rectangle(
            images, centers_x, centers_y, width, height, fill
        )
        # remove batch dimension and channel dimension
        filled_images = filled_images[0, ..., 0]

        tf.assert_equal(filled_images, expected)

    def test_rectangle_position(self):
        img_w, img_h = 8, 8
        cent_x, cent_y = 4, 3
        rec_w, rec_h = 5, 3
        expected = tf.constant(
            [
                [1, 1, 1, 1, 1, 1, 1, 1],
                [1, 1, 1, 1, 1, 1, 1, 1],
                [1, 1, 0, 0, 0, 0, 0, 1],
                [1, 1, 0, 0, 0, 0, 0, 1],
                [1, 1, 0, 0, 0, 0, 0, 1],
                [1, 1, 1, 1, 1, 1, 1, 1],
                [1, 1, 1, 1, 1, 1, 1, 1],
                [1, 1, 1, 1, 1, 1, 1, 1],
            ],
            dtype=tf.int32,
        )
        self._run_test(img_w, img_h, cent_x, cent_y, rec_w, rec_h, expected)

    def test_width_out_of_lower_bound(self):
        img_w, img_h = 8, 8
        cent_x, cent_y = 1, 3
        rec_w, rec_h = 5, 3
        # assert width is truncated when cent_x - rec_w < 0
        expected = tf.constant(
            [
                [1, 1, 1, 1, 1, 1, 1, 1],
                [1, 1, 1, 1, 1, 1, 1, 1],
                [0, 0, 0, 0, 1, 1, 1, 1],
                [0, 0, 0, 0, 1, 1, 1, 1],
                [0, 0, 0, 0, 1, 1, 1, 1],
                [1, 1, 1, 1, 1, 1, 1, 1],
                [1, 1, 1, 1, 1, 1, 1, 1],
                [1, 1, 1, 1, 1, 1, 1, 1],
            ],
            dtype=tf.int32,
        )
        self._run_test(img_w, img_h, cent_x, cent_y, rec_w, rec_h, expected)

    def test_width_out_of_upper_bound(self):
        img_w, img_h = 8, 8
        cent_x, cent_y = 6, 3
        rec_w, rec_h = 5, 3
        # assert width is truncated when cent_x + rec_w > img_w
        expected = tf.constant(
            [
                [1, 1, 1, 1, 1, 1, 1, 1],
                [1, 1, 1, 1, 1, 1, 1, 1],
                [1, 1, 1, 1, 0, 0, 0, 0],
                [1, 1, 1, 1, 0, 0, 0, 0],
                [1, 1, 1, 1, 0, 0, 0, 0],
                [1, 1, 1, 1, 1, 1, 1, 1],
                [1, 1, 1, 1, 1, 1, 1, 1],
                [1, 1, 1, 1, 1, 1, 1, 1],
            ],
            dtype=tf.int32,
        )
        self._run_test(img_w, img_h, cent_x, cent_y, rec_w, rec_h, expected)

    def test_height_out_of_lower_bound(self):
        img_w, img_h = 8, 8
        cent_x, cent_y = 4, 1
        rec_w, rec_h = 3, 5
        # assert height is truncated when cent_y - rec_h < 0
        expected = tf.constant(
            [
                [1, 1, 1, 0, 0, 0, 1, 1],
                [1, 1, 1, 0, 0, 0, 1, 1],
                [1, 1, 1, 0, 0, 0, 1, 1],
                [1, 1, 1, 0, 0, 0, 1, 1],
                [1, 1, 1, 1, 1, 1, 1, 1],
                [1, 1, 1, 1, 1, 1, 1, 1],
                [1, 1, 1, 1, 1, 1, 1, 1],
                [1, 1, 1, 1, 1, 1, 1, 1],
            ],
            dtype=tf.int32,
        )
        self._run_test(img_w, img_h, cent_x, cent_y, rec_w, rec_h, expected)

    def test_height_out_of_upper_bound(self):
        img_w, img_h = 8, 8
        cent_x, cent_y = 4, 6
        rec_w, rec_h = 3, 5
        # assert height is truncated when cent_y + rec_h > img_h
        expected = tf.constant(
            [
                [1, 1, 1, 1, 1, 1, 1, 1],
                [1, 1, 1, 1, 1, 1, 1, 1],
                [1, 1, 1, 1, 1, 1, 1, 1],
                [1, 1, 1, 1, 1, 1, 1, 1],
                [1, 1, 1, 0, 0, 0, 1, 1],
                [1, 1, 1, 0, 0, 0, 1, 1],
                [1, 1, 1, 0, 0, 0, 1, 1],
                [1, 1, 1, 0, 0, 0, 1, 1],
            ],
            dtype=tf.int32,
        )
        self._run_test(img_w, img_h, cent_x, cent_y, rec_w, rec_h, expected)

    def test_different_fill(self):
        batch_size = 2
        img_w, img_h = 5, 5
        cent_x, cent_y = 2, 2
        rec_w, rec_h = 3, 3

        batch_shape = (batch_size, img_h, img_w, 1)
        images = tf.ones(batch_shape, dtype=tf.int32)

        centers_x = tf.fill([batch_size], cent_x)
        centers_y = tf.fill([batch_size], cent_y)
        width = tf.fill([batch_size], rec_w)
        height = tf.fill([batch_size], rec_h)
        fill = tf.stack(
            [tf.fill(images[0].shape, 2), tf.fill(images[1].shape, 3)]
        )

        filled_images = fill_utils.fill_rectangle(
            images, centers_x, centers_y, width, height, fill
        )
        # remove channel dimension
        filled_images = filled_images[..., 0]

        expected = tf.constant(
            [
                [
                    [1, 1, 1, 1, 1],
                    [1, 2, 2, 2, 1],
                    [1, 2, 2, 2, 1],
                    [1, 2, 2, 2, 1],
                    [1, 1, 1, 1, 1],
                ],
                [
                    [1, 1, 1, 1, 1],
                    [1, 3, 3, 3, 1],
                    [1, 3, 3, 3, 1],
                    [1, 3, 3, 3, 1],
                    [1, 1, 1, 1, 1],
                ],
            ],
            dtype=tf.int32,
        )
        tf.assert_equal(filled_images, expected)
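The `fill_rectangle` behavior these tests pin down (convert center/width/height to corner coordinates, build the membership mask, splice in the fill values) can be sketched in plain NumPy, using the same `center_xywh` convention:

```python
import numpy as np

# NumPy sketch of fill_rectangle: center_xywh -> corners, boolean mask,
# then np.where to splice fill values into the images (CutOut-style).

def fill_rectangle(images, cx, cy, w, h, fill_values):
    _, img_h, img_w, _ = images.shape
    x0, x1 = cx - w / 2, cx + w / 2
    y0, y1 = cy - h / 2, cy + h / 2
    xs = np.arange(img_w)
    ys = np.arange(img_h)
    w_mask = (xs[None, :] >= x0[:, None]) & (xs[None, :] < x1[:, None])
    h_mask = (ys[None, :] >= y0[:, None]) & (ys[None, :] < y1[:, None])
    inside = h_mask[:, :, None] & w_mask[:, None, :]  # (batch, H, W)
    return np.where(inside[..., None], fill_values, images)

images = np.ones((1, 8, 8, 1), dtype=np.int32)
out = fill_rectangle(
    images,
    np.array([4]), np.array([3]),  # center (x=4, y=3)
    np.array([5]), np.array([3]),  # width 5, height 3
    np.zeros_like(images),
)
```

With center (4, 3) and a 5x3 rectangle this zeroes columns 2..6 of rows 2..4, the same pattern `test_rectangle_position` expects of the TensorFlow implementation.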