Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
a26d77c4
Commit
a26d77c4
authored
Oct 09, 2020
by
Zhenyu Tan
Committed by
A. Unique TensorFlower
Oct 09, 2020
Browse files
Internal change
PiperOrigin-RevId: 336353171
parent
6ac8ca60
Changes
9
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
289 additions
and
209 deletions
+289
-209
official/vision/beta/modeling/layers/box_matcher.py
official/vision/beta/modeling/layers/box_matcher.py
+0
-141
official/vision/beta/modeling/layers/roi_sampler.py
official/vision/beta/modeling/layers/roi_sampler.py
+11
-14
official/vision/beta/ops/anchor.py
official/vision/beta/ops/anchor.py
+6
-6
official/vision/keras_cv/ops/__init__.py
official/vision/keras_cv/ops/__init__.py
+1
-1
official/vision/keras_cv/ops/anchor_generator_test.py
official/vision/keras_cv/ops/anchor_generator_test.py
+0
-21
official/vision/keras_cv/ops/box_matcher_test.py
official/vision/keras_cv/ops/box_matcher_test.py
+78
-0
official/vision/keras_cv/ops/iou_similarity_test.py
official/vision/keras_cv/ops/iou_similarity_test.py
+76
-0
official/vision/keras_cv/ops/target_gather.py
official/vision/keras_cv/ops/target_gather.py
+40
-26
official/vision/keras_cv/ops/target_gather_test.py
official/vision/keras_cv/ops/target_gather_test.py
+77
-0
No files found.
official/vision/beta/modeling/layers/box_matcher.py
deleted
100644 → 0
View file @
6ac8ca60
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Box matcher."""
# Import libraries
import
tensorflow
as
tf
from
official.vision.beta.ops
import
box_ops
@
tf
.
keras
.
utils
.
register_keras_serializable
(
package
=
'Vision'
)
class
BoxMatcher
(
tf
.
keras
.
layers
.
Layer
):
"""Match boxes with groundtruth boxes."""
def
__init__
(
self
,
foreground_iou_threshold
=
0.5
,
background_iou_high_threshold
=
0.5
,
background_iou_low_threshold
=
0
,
**
kwargs
):
"""Initializes a box matcher.
Args:
foreground_iou_threshold: float, represent the IoU threshold for a box to
be considered as positive (if >= `foreground_iou_threshold`).
background_iou_high_threshold: float, represent the IoU threshold for a
box to be considered as negative (if overlap in
[`background_iou_low_threshold`, `background_iou_high_threshold`]).
background_iou_low_threshold: float, represent the IoU threshold for a box
to be considered as negative (if overlap in
[`background_iou_low_threshold`, `background_iou_high_threshold`])
**kwargs: other key word arguments passed to Layer.
"""
self
.
_config_dict
=
{
'foreground_iou_threshold'
:
foreground_iou_threshold
,
'background_iou_high_threshold'
:
background_iou_high_threshold
,
'background_iou_low_threshold'
:
background_iou_low_threshold
,
}
super
(
BoxMatcher
,
self
).
__init__
(
**
kwargs
)
def
call
(
self
,
boxes
,
gt_boxes
,
gt_classes
):
"""Match boxes to groundtruth boxes.
Given the proposal boxes and the groundtruth boxes and classes, perform the
groundtruth matching by taking the argmax of the IoU between boxes and
groundtruth boxes.
Args:
boxes: a tensor of shape of [batch_size, N, 4] representing the box
coordianates to be matched to groundtruth boxes.
gt_boxes: a tensor of shape of [batch_size, MAX_INSTANCES, 4] representing
the groundtruth box coordinates. It is padded with -1s to indicate the
invalid boxes.
gt_classes: [batch_size, MAX_INSTANCES] representing the groundtruth box
classes. It is padded with -1s to indicate the invalid classes.
Returns:
matched_gt_boxes: a tensor of shape of [batch, N, 4], representing
the matched groundtruth box coordinates for each input box. The box is
considered to match to a groundtruth box only if the IoU overlap is
greater than `foreground_iou_threshold`. If the box is a negative match,
or does not overlap with any groundtruth boxes, the matched boxes will
be set to all 0s.
matched_gt_classes: a tensor of shape of [batch, N], representing
the matched groundtruth classes for each input box. If the box is a
negative match or does not overlap with any groundtruth boxes, the
matched classes of it will be set to 0, which corresponds to the
background class.
matched_gt_indices: a tensor of shape of [batch, N], representing the
indices of the matched groundtruth boxes in the original gt_boxes
tensor. If the box is a negative match or does not overlap with any
groundtruth boxes, the index of the matched groundtruth will be set to
-1.
positive_matches: a bool tensor of shape of [batch, N], representing
whether each box is a positive matches or not. A positive match is the
case where IoU of a box with any groundtruth box is greater than
`foreground_iou_threshold`.
negative_matches: a bool tensor of shape of [batch, N], representing
whether each box is a negative matches or not. A negative match is the
case where IoU of a box with any groundtruth box is greater than
`background_iou_low_threshold` and less than
`background_iou_low_threshold`.
ignored_matches: a bool tensor of shape of [batch, N], representing
whether each box is an ignored matches or not. An ignored matches is the
match that is neither positive or negative.
"""
matched_gt_boxes
,
matched_gt_classes
,
matched_gt_indices
,
matched_iou
,
_
=
(
box_ops
.
box_matching
(
boxes
,
gt_boxes
,
gt_classes
))
positive_matches
=
tf
.
greater
(
matched_iou
,
self
.
_config_dict
[
'foreground_iou_threshold'
])
negative_matches
=
tf
.
logical_and
(
tf
.
greater_equal
(
matched_iou
,
self
.
_config_dict
[
'background_iou_low_threshold'
]),
tf
.
less
(
matched_iou
,
self
.
_config_dict
[
'background_iou_high_threshold'
]))
ignored_matches
=
tf
.
logical_and
(
tf
.
less
(
matched_iou
,
0.0
),
tf
.
greater_equal
(
matched_iou
,
self
.
_config_dict
[
'background_iou_high_threshold'
]))
ignored_matches
=
tf
.
logical_and
(
ignored_matches
,
tf
.
less
(
matched_iou
,
self
.
_config_dict
[
'foreground_iou_threshold'
]))
background_indicator
=
tf
.
logical_or
(
negative_matches
,
ignored_matches
)
# re-assign negatively matched boxes to the background class.
matched_gt_boxes
=
tf
.
where
(
tf
.
tile
(
tf
.
expand_dims
(
background_indicator
,
-
1
),
[
1
,
1
,
4
]),
tf
.
zeros_like
(
matched_gt_boxes
),
matched_gt_boxes
)
matched_gt_classes
=
tf
.
where
(
background_indicator
,
tf
.
zeros_like
(
matched_gt_classes
),
matched_gt_classes
)
matched_gt_indices
=
tf
.
where
(
background_indicator
,
-
tf
.
ones_like
(
matched_gt_indices
),
matched_gt_indices
)
return
(
matched_gt_boxes
,
matched_gt_classes
,
matched_gt_indices
,
positive_matches
,
negative_matches
,
ignored_matches
)
def
get_config
(
self
):
return
self
.
_config_dict
@
classmethod
def
from_config
(
cls
,
config
):
return
cls
(
**
config
)
official/vision/beta/modeling/layers/roi_sampler.py
View file @
a26d77c4
...
@@ -19,7 +19,6 @@ import tensorflow as tf
...
@@ -19,7 +19,6 @@ import tensorflow as tf
from
official.vision
import
keras_cv
from
official.vision
import
keras_cv
from
official.vision.beta.modeling.layers
import
box_sampler
from
official.vision.beta.modeling.layers
import
box_sampler
from
official.vision.beta.ops
import
box_ops
@
tf
.
keras
.
utils
.
register_keras_serializable
(
package
=
'Vision'
)
@
tf
.
keras
.
utils
.
register_keras_serializable
(
package
=
'Vision'
)
...
@@ -68,7 +67,7 @@ class ROISampler(tf.keras.layers.Layer):
...
@@ -68,7 +67,7 @@ class ROISampler(tf.keras.layers.Layer):
foreground_iou_threshold
foreground_iou_threshold
],
],
indicators
=
[
-
3
,
-
1
,
-
2
,
1
])
indicators
=
[
-
3
,
-
1
,
-
2
,
1
])
self
.
_
anchor_label
er
=
keras_cv
.
ops
.
AnchorLabel
er
()
self
.
_
target_gath
er
=
keras_cv
.
ops
.
TargetGath
er
()
self
.
_sampler
=
box_sampler
.
BoxSampler
(
self
.
_sampler
=
box_sampler
.
BoxSampler
(
num_sampled_rois
,
foreground_fraction
)
num_sampled_rois
,
foreground_fraction
)
...
@@ -130,13 +129,12 @@ class ROISampler(tf.keras.layers.Layer):
...
@@ -130,13 +129,12 @@ class ROISampler(tf.keras.layers.Layer):
background_mask
=
tf
.
expand_dims
(
background_mask
=
tf
.
expand_dims
(
tf
.
logical_or
(
negative_matches
,
invalid_matches
),
-
1
)
tf
.
logical_or
(
negative_matches
,
invalid_matches
),
-
1
)
gt_classes
=
tf
.
expand_dims
(
gt_classes
,
axis
=-
1
)
gt_classes
=
tf
.
expand_dims
(
gt_classes
,
axis
=-
1
)
matched_gt_classes
=
self
.
_
anchor_label
er
(
gt_classes
,
matched_gt_indices
,
matched_gt_classes
=
self
.
_
target_gath
er
(
gt_classes
,
matched_gt_indices
,
background_mask
)
background_mask
)
matched_gt_classes
=
tf
.
where
(
background_mask
,
matched_gt_classes
=
tf
.
where
(
background_mask
,
tf
.
zeros_like
(
matched_gt_classes
),
tf
.
zeros_like
(
matched_gt_classes
),
matched_gt_classes
)
matched_gt_classes
)
matched_gt_classes
=
tf
.
squeeze
(
matched_gt_classes
,
axis
=-
1
)
matched_gt_boxes
=
self
.
_target_gather
(
gt_boxes
,
matched_gt_indices
,
matched_gt_boxes
=
self
.
_anchor_labeler
(
gt_boxes
,
matched_gt_indices
,
tf
.
tile
(
background_mask
,
[
1
,
1
,
4
]))
tf
.
tile
(
background_mask
,
[
1
,
1
,
4
]))
matched_gt_boxes
=
tf
.
where
(
background_mask
,
matched_gt_boxes
=
tf
.
where
(
background_mask
,
tf
.
zeros_like
(
matched_gt_boxes
),
tf
.
zeros_like
(
matched_gt_boxes
),
...
@@ -148,13 +146,12 @@ class ROISampler(tf.keras.layers.Layer):
...
@@ -148,13 +146,12 @@ class ROISampler(tf.keras.layers.Layer):
sampled_indices
=
self
.
_sampler
(
sampled_indices
=
self
.
_sampler
(
positive_matches
,
negative_matches
,
ignored_matches
)
positive_matches
,
negative_matches
,
ignored_matches
)
sampled_rois
,
sampled_gt_boxes
,
sampled_gt_classes
,
sampled_gt_indices
=
(
sampled_rois
=
self
.
_target_gather
(
boxes
,
sampled_indices
)
box_ops
.
gather_instances
(
sampled_gt_boxes
=
self
.
_target_gather
(
matched_gt_boxes
,
sampled_indices
)
sampled_indices
,
sampled_gt_classes
=
tf
.
squeeze
(
self
.
_target_gather
(
boxes
,
matched_gt_classes
,
sampled_indices
),
axis
=-
1
)
matched_gt_boxes
,
sampled_gt_indices
=
tf
.
squeeze
(
self
.
_target_gather
(
matched_gt_classes
,
tf
.
expand_dims
(
matched_gt_indices
,
-
1
),
sampled_indices
),
axis
=-
1
)
matched_gt_indices
))
return
(
sampled_rois
,
sampled_gt_boxes
,
sampled_gt_classes
,
return
(
sampled_rois
,
sampled_gt_boxes
,
sampled_gt_classes
,
sampled_gt_indices
)
sampled_gt_indices
)
...
...
official/vision/beta/ops/anchor.py
View file @
a26d77c4
...
@@ -133,7 +133,7 @@ class AnchorLabeler(object):
...
@@ -133,7 +133,7 @@ class AnchorLabeler(object):
with a score below the threshold is labeled negative.
with a score below the threshold is labeled negative.
"""
"""
self
.
similarity_calc
=
keras_cv
.
ops
.
IouSimilarity
()
self
.
similarity_calc
=
keras_cv
.
ops
.
IouSimilarity
()
self
.
anchor_label
er
=
keras_cv
.
ops
.
AnchorLabel
er
()
self
.
target_gath
er
=
keras_cv
.
ops
.
TargetGath
er
()
self
.
matcher
=
keras_cv
.
ops
.
BoxMatcher
(
self
.
matcher
=
keras_cv
.
ops
.
BoxMatcher
(
thresholds
=
[
unmatched_threshold
,
match_threshold
],
thresholds
=
[
unmatched_threshold
,
match_threshold
],
indicators
=
[
-
1
,
-
2
,
1
],
indicators
=
[
-
1
,
-
2
,
1
],
...
@@ -177,13 +177,13 @@ class AnchorLabeler(object):
...
@@ -177,13 +177,13 @@ class AnchorLabeler(object):
match_indices
,
match_indicators
=
self
.
matcher
(
similarity_matrix
)
match_indices
,
match_indicators
=
self
.
matcher
(
similarity_matrix
)
mask
=
tf
.
less_equal
(
match_indicators
,
0
)
mask
=
tf
.
less_equal
(
match_indicators
,
0
)
cls_mask
=
tf
.
expand_dims
(
mask
,
-
1
)
cls_mask
=
tf
.
expand_dims
(
mask
,
-
1
)
cls_targets
=
self
.
anchor_label
er
(
gt_labels
,
match_indices
,
cls_mask
,
-
1
)
cls_targets
=
self
.
target_gath
er
(
gt_labels
,
match_indices
,
cls_mask
,
-
1
)
box_mask
=
tf
.
tile
(
cls_mask
,
[
1
,
4
])
box_mask
=
tf
.
tile
(
cls_mask
,
[
1
,
4
])
box_targets
=
self
.
anchor_label
er
(
gt_boxes
,
match_indices
,
box_mask
)
box_targets
=
self
.
target_gath
er
(
gt_boxes
,
match_indices
,
box_mask
)
weights
=
tf
.
squeeze
(
tf
.
ones_like
(
gt_labels
,
dtype
=
tf
.
float32
),
-
1
)
weights
=
tf
.
squeeze
(
tf
.
ones_like
(
gt_labels
,
dtype
=
tf
.
float32
),
-
1
)
box_weights
=
self
.
anchor_label
er
(
weights
,
match_indices
,
mask
)
box_weights
=
self
.
target_gath
er
(
weights
,
match_indices
,
mask
)
ignore_mask
=
tf
.
equal
(
match_indicators
,
-
2
)
ignore_mask
=
tf
.
equal
(
match_indicators
,
-
2
)
cls_weights
=
self
.
anchor_label
er
(
weights
,
match_indices
,
ignore_mask
)
cls_weights
=
self
.
target_gath
er
(
weights
,
match_indices
,
ignore_mask
)
box_targets_list
=
box_list
.
BoxList
(
box_targets
)
box_targets_list
=
box_list
.
BoxList
(
box_targets
)
anchor_box_list
=
box_list
.
BoxList
(
flattened_anchor_boxes
)
anchor_box_list
=
box_list
.
BoxList
(
flattened_anchor_boxes
)
box_targets
=
self
.
box_coder
.
encode
(
box_targets_list
,
anchor_box_list
)
box_targets
=
self
.
box_coder
.
encode
(
box_targets_list
,
anchor_box_list
)
...
@@ -279,7 +279,7 @@ class RpnAnchorLabeler(AnchorLabeler):
...
@@ -279,7 +279,7 @@ class RpnAnchorLabeler(AnchorLabeler):
match_indices
,
match_indicators
=
self
.
matcher
(
similarity_matrix
)
match_indices
,
match_indicators
=
self
.
matcher
(
similarity_matrix
)
box_mask
=
tf
.
tile
(
tf
.
expand_dims
(
tf
.
less_equal
(
match_indicators
,
0
),
-
1
),
box_mask
=
tf
.
tile
(
tf
.
expand_dims
(
tf
.
less_equal
(
match_indicators
,
0
),
-
1
),
[
1
,
4
])
[
1
,
4
])
box_targets
=
self
.
anchor_label
er
(
gt_boxes
,
match_indices
,
box_mask
)
box_targets
=
self
.
target_gath
er
(
gt_boxes
,
match_indices
,
box_mask
)
box_targets_list
=
box_list
.
BoxList
(
box_targets
)
box_targets_list
=
box_list
.
BoxList
(
box_targets
)
anchor_box_list
=
box_list
.
BoxList
(
flattened_anchor_boxes
)
anchor_box_list
=
box_list
.
BoxList
(
flattened_anchor_boxes
)
box_targets
=
self
.
box_coder
.
encode
(
box_targets_list
,
anchor_box_list
)
box_targets
=
self
.
box_coder
.
encode
(
box_targets_list
,
anchor_box_list
)
...
...
official/vision/keras_cv/ops/__init__.py
View file @
a26d77c4
...
@@ -14,6 +14,6 @@
...
@@ -14,6 +14,6 @@
# ==============================================================================
# ==============================================================================
"""Keras-CV layers package definition."""
"""Keras-CV layers package definition."""
from
official.vision.keras_cv.ops.anchor_generator
import
AnchorGenerator
from
official.vision.keras_cv.ops.anchor_generator
import
AnchorGenerator
from
official.vision.keras_cv.ops.anchor_labeler
import
AnchorLabeler
from
official.vision.keras_cv.ops.box_matcher
import
BoxMatcher
from
official.vision.keras_cv.ops.box_matcher
import
BoxMatcher
from
official.vision.keras_cv.ops.iou_similarity
import
IouSimilarity
from
official.vision.keras_cv.ops.iou_similarity
import
IouSimilarity
from
official.vision.keras_cv.ops.target_gather
import
TargetGather
official/vision/keras_cv/ops/anchor_generator_test.py
View file @
a26d77c4
...
@@ -16,8 +16,6 @@
...
@@ -16,8 +16,6 @@
from
absl.testing
import
parameterized
from
absl.testing
import
parameterized
import
tensorflow
as
tf
import
tensorflow
as
tf
from
tensorflow.python.distribute
import
combinations
from
tensorflow.python.distribute
import
strategy_combinations
from
official.vision.keras_cv.ops
import
anchor_generator
from
official.vision.keras_cv.ops
import
anchor_generator
...
@@ -65,25 +63,6 @@ class AnchorGeneratorTest(parameterized.TestCase, tf.test.TestCase):
...
@@ -65,25 +63,6 @@ class AnchorGeneratorTest(parameterized.TestCase, tf.test.TestCase):
anchors
=
anchor_gen
(
image_size
).
numpy
()
anchors
=
anchor_gen
(
image_size
).
numpy
()
self
.
assertAllClose
(
expected_boxes
,
anchors
)
self
.
assertAllClose
(
expected_boxes
,
anchors
)
@
combinations
.
generate
(
combinations
.
combine
(
distribution
=
strategy_combinations
.
all_strategies
))
def
testAnchorGenerationDistributed
(
self
,
distribution
):
image_size
=
[
64
,
64
]
anchor_size
=
64
stride
=
32
aspect_ratios
=
[
1.0
]
with
distribution
.
scope
():
anchor_gen
=
anchor_generator
.
_SingleAnchorGenerator
(
anchor_size
=
anchor_size
,
scales
=
[
1.
],
aspect_ratios
=
aspect_ratios
,
stride
=
stride
,
clip_boxes
=
False
)
anchors
=
anchor_gen
(
image_size
).
numpy
()
expected_boxes
=
[[[
-
16.
,
-
16.
,
48.
,
48.
],
[
-
16.
,
16.
,
48.
,
80.
]],
[[
16.
,
-
16.
,
80.
,
48.
],
[
16.
,
16.
,
80.
,
80.
]]]
self
.
assertAllClose
(
expected_boxes
,
anchors
)
class
MultiScaleAnchorGeneratorTest
(
parameterized
.
TestCase
,
tf
.
test
.
TestCase
):
class
MultiScaleAnchorGeneratorTest
(
parameterized
.
TestCase
,
tf
.
test
.
TestCase
):
...
...
official/vision/keras_cv/ops/box_matcher_test.py
0 → 100644
View file @
a26d77c4
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for box_matcher.py."""
import
tensorflow
as
tf
from
official.vision.keras_cv.ops
import
box_matcher
class
BoxMatcherTest
(
tf
.
test
.
TestCase
):
def
test_box_matcher_unbatched
(
self
):
sim_matrix
=
tf
.
constant
(
[[
0.04
,
0
,
0
,
0
],
[
0
,
0
,
1.
,
0
]],
dtype
=
tf
.
float32
)
fg_threshold
=
0.5
bg_thresh_hi
=
0.2
bg_thresh_lo
=
0.0
matcher
=
box_matcher
.
BoxMatcher
(
thresholds
=
[
bg_thresh_lo
,
bg_thresh_hi
,
fg_threshold
],
indicators
=
[
-
3
,
-
2
,
-
1
,
1
])
match_indices
,
match_indicators
=
matcher
(
sim_matrix
)
positive_matches
=
tf
.
greater_equal
(
match_indicators
,
0
)
negative_matches
=
tf
.
equal
(
match_indicators
,
-
2
)
self
.
assertAllEqual
(
positive_matches
.
numpy
(),
[
False
,
True
])
self
.
assertAllEqual
(
negative_matches
.
numpy
(),
[
True
,
False
])
self
.
assertAllEqual
(
match_indices
.
numpy
(),
[
0
,
2
])
self
.
assertAllEqual
(
match_indicators
.
numpy
(),
[
-
2
,
1
])
def
test_box_matcher_batched
(
self
):
sim_matrix
=
tf
.
constant
(
[[[
0.04
,
0
,
0
,
0
],
[
0
,
0
,
1.
,
0
]]],
dtype
=
tf
.
float32
)
fg_threshold
=
0.5
bg_thresh_hi
=
0.2
bg_thresh_lo
=
0.0
matcher
=
box_matcher
.
BoxMatcher
(
thresholds
=
[
bg_thresh_lo
,
bg_thresh_hi
,
fg_threshold
],
indicators
=
[
-
3
,
-
2
,
-
1
,
1
])
match_indices
,
match_indicators
=
matcher
(
sim_matrix
)
positive_matches
=
tf
.
greater_equal
(
match_indicators
,
0
)
negative_matches
=
tf
.
equal
(
match_indicators
,
-
2
)
self
.
assertAllEqual
(
positive_matches
.
numpy
(),
[[
False
,
True
]])
self
.
assertAllEqual
(
negative_matches
.
numpy
(),
[[
True
,
False
]])
self
.
assertAllEqual
(
match_indices
.
numpy
(),
[[
0
,
2
]])
self
.
assertAllEqual
(
match_indicators
.
numpy
(),
[[
-
2
,
1
]])
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
official/vision/keras_cv/ops/iou_similarity_test.py
0 → 100644
View file @
a26d77c4
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for iou_similarity.py."""
import
tensorflow
as
tf
from
official.vision.keras_cv.ops
import
iou_similarity
class
BoxMatcherTest
(
tf
.
test
.
TestCase
):
def
test_similarity_unbatched
(
self
):
boxes
=
tf
.
constant
(
[
[
0
,
0
,
1
,
1
],
[
5
,
0
,
10
,
5
],
],
dtype
=
tf
.
float32
)
gt_boxes
=
tf
.
constant
(
[
[
0
,
0
,
5
,
5
],
[
0
,
5
,
5
,
10
],
[
5
,
0
,
10
,
5
],
[
5
,
5
,
10
,
10
],
],
dtype
=
tf
.
float32
)
sim_calc
=
iou_similarity
.
IouSimilarity
()
sim_matrix
=
sim_calc
(
boxes
,
gt_boxes
)
self
.
assertAllClose
(
sim_matrix
.
numpy
(),
[[
0.04
,
0
,
0
,
0
],
[
0
,
0
,
1.
,
0
]])
def
test_similarity_batched
(
self
):
boxes
=
tf
.
constant
(
[[
[
0
,
0
,
1
,
1
],
[
5
,
0
,
10
,
5
],
]],
dtype
=
tf
.
float32
)
gt_boxes
=
tf
.
constant
(
[[
[
0
,
0
,
5
,
5
],
[
0
,
5
,
5
,
10
],
[
5
,
0
,
10
,
5
],
[
5
,
5
,
10
,
10
],
]],
dtype
=
tf
.
float32
)
sim_calc
=
iou_similarity
.
IouSimilarity
()
sim_matrix
=
sim_calc
(
boxes
,
gt_boxes
)
self
.
assertAllClose
(
sim_matrix
.
numpy
(),
[[[
0.04
,
0
,
0
,
0
],
[
0
,
0
,
1.
,
0
]]])
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
official/vision/keras_cv/ops/
anchor_label
er.py
→
official/vision/keras_cv/ops/
target_gath
er.py
View file @
a26d77c4
...
@@ -12,57 +12,61 @@
...
@@ -12,57 +12,61 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# ==============================================================================
# ==============================================================================
"""Definition of
anchor labeler, which assigns ground truth boxes to anchor
s."""
"""Definition of
target gather, which gathers targets from indice
s."""
import
tensorflow
as
tf
import
tensorflow
as
tf
class
AnchorLabel
er
:
class
TargetGath
er
:
"""
Label
er for dense object detector."""
"""
Targer gath
er for dense object detector."""
def
__call__
(
self
,
labels
,
match_indices
,
mask
,
mask_val
=
0.0
):
def
__call__
(
self
,
labels
,
match_indices
,
mask
=
None
,
mask_val
=
0.0
):
"""Labels anchors with ground truth inputs.
"""Labels anchors with ground truth inputs.
B: batch_size
B: batch_size
N: number of groundtruth boxes.
N: number of groundtruth boxes.
Args:
Args:
labels: An integer tensor with shape [N,
1
] or [B, N,
1
] representing
labels: An integer tensor with shape [N,
dims
] or [B, N,
...
] representing
groundtruth labels.
groundtruth labels.
match_indices: An integer tensor with shape [
N
] or [B,
N
] representing
match_indices: An integer tensor with shape [
M
] or [B,
M
] representing
match label index.
match label index.
mask: An
integer
tensor with shape [
N
] or [B,
N
] representing
match
mask: An
boolean
tensor with shape [
M, dims
] or [B,
M,...
] representing
labels
, e.g., 1 for positive, -1 for negative, -2 for ignore
.
match
labels.
mask_val: An integer to fill in for mask.
mask_val: An integer to fill in for mask.
Returns:
Returns:
class_targets: A integer Tensor with shape [num_anchors].
target: An integer Tensor with shape [M] or [B, M]
box_targets: A float Tensor with shape [num_anchors, 4].
Raises:
class_weights: A float Tensor with shape [num_anchors], that
ValueError: If `labels` is higher than rank 3.
serves as masking / sample weight for classification loss. Its value
is 1.0 for positive and negative matched anchors, and 0.0 for ignored
anchors.
box_weights: A float Tensor with shape [num_anchors], that
serves as masking / sample weight for regression loss. Its value is
1.0 for positive matched anchors, and 0.0 for negative and ignored
anchors.
"""
"""
if
len
(
labels
.
shape
)
<=
2
:
if
len
(
labels
.
shape
)
<=
2
:
return
self
.
_gather_unbatched
(
labels
,
match_indices
,
mask
,
mask_val
)
return
self
.
_gather_unbatched
(
labels
,
match_indices
,
mask
,
mask_val
)
elif
len
(
labels
.
shape
)
==
3
:
elif
len
(
labels
.
shape
)
==
3
:
return
self
.
_gather_batched
(
labels
,
match_indices
,
mask
,
mask_val
)
return
self
.
_gather_batched
(
labels
,
match_indices
,
mask
,
mask_val
)
else
:
raise
ValueError
(
"`TargetGather` does not support `labels` with rank "
"larger than 3, got {}"
.
format
(
len
(
labels
.
shape
)))
def
_gather_unbatched
(
self
,
labels
,
match_indices
,
mask
,
mask_val
):
def
_gather_unbatched
(
self
,
labels
,
match_indices
,
mask
,
mask_val
):
"""Gather based on unbatched labels and boxes."""
"""Gather based on unbatched labels and boxes."""
num_gt_boxes
=
tf
.
shape
(
labels
)[
0
]
num_gt_boxes
=
tf
.
shape
(
labels
)[
0
]
masked_targets
=
tf
.
cast
(
mask_val
,
labels
.
dtype
)
*
tf
.
ones_like
(
mask
,
dtype
=
labels
.
dtype
)
def
_assign_when_rows_empty
():
def
_assign_when_rows_empty
():
return
masked_targets
if
len
(
labels
.
shape
)
>
1
:
mask_shape
=
[
match_indices
.
shape
[
0
],
labels
.
shape
[
-
1
]]
else
:
mask_shape
=
[
match_indices
.
shape
[
0
]]
return
tf
.
cast
(
mask_val
,
labels
.
dtype
)
*
tf
.
ones
(
mask_shape
,
dtype
=
labels
.
dtype
)
def
_assign_when_rows_not_empty
():
def
_assign_when_rows_not_empty
():
targets
=
tf
.
gather
(
labels
,
match_indices
)
targets
=
tf
.
gather
(
labels
,
match_indices
)
if
mask
is
None
:
return
targets
else
:
masked_targets
=
tf
.
cast
(
mask_val
,
labels
.
dtype
)
*
tf
.
ones_like
(
mask
,
dtype
=
labels
.
dtype
)
return
tf
.
where
(
mask
,
masked_targets
,
targets
)
return
tf
.
where
(
mask
,
masked_targets
,
targets
)
return
tf
.
cond
(
tf
.
greater
(
num_gt_boxes
,
0
),
return
tf
.
cond
(
tf
.
greater
(
num_gt_boxes
,
0
),
...
@@ -73,9 +77,14 @@ class AnchorLabeler:
...
@@ -73,9 +77,14 @@ class AnchorLabeler:
"""Gather based on batched labels."""
"""Gather based on batched labels."""
batch_size
=
labels
.
shape
[
0
]
batch_size
=
labels
.
shape
[
0
]
if
batch_size
==
1
:
if
batch_size
==
1
:
if
mask
is
not
None
:
result
=
self
.
_gather_unbatched
(
result
=
self
.
_gather_unbatched
(
tf
.
squeeze
(
labels
,
axis
=
0
),
tf
.
squeeze
(
match_indices
,
axis
=
0
),
tf
.
squeeze
(
labels
,
axis
=
0
),
tf
.
squeeze
(
match_indices
,
axis
=
0
),
tf
.
squeeze
(
mask
,
axis
=
0
),
mask_val
)
tf
.
squeeze
(
mask
,
axis
=
0
),
mask_val
)
else
:
result
=
self
.
_gather_unbatched
(
tf
.
squeeze
(
labels
,
axis
=
0
),
tf
.
squeeze
(
match_indices
,
axis
=
0
),
None
,
mask_val
)
return
tf
.
expand_dims
(
result
,
axis
=
0
)
return
tf
.
expand_dims
(
result
,
axis
=
0
)
else
:
else
:
indices_shape
=
tf
.
shape
(
match_indices
)
indices_shape
=
tf
.
shape
(
match_indices
)
...
@@ -86,4 +95,9 @@ class AnchorLabeler:
...
@@ -86,4 +95,9 @@ class AnchorLabeler:
gather_nd_indices
=
tf
.
stack
(
gather_nd_indices
=
tf
.
stack
(
[
batch_indices
,
match_indices
],
axis
=-
1
)
[
batch_indices
,
match_indices
],
axis
=-
1
)
targets
=
tf
.
gather_nd
(
labels
,
gather_nd_indices
)
targets
=
tf
.
gather_nd
(
labels
,
gather_nd_indices
)
if
mask
is
None
:
return
targets
return
targets
else
:
masked_targets
=
tf
.
cast
(
mask_val
,
labels
.
dtype
)
*
tf
.
ones_like
(
mask
,
dtype
=
labels
.
dtype
)
return
tf
.
where
(
mask
,
masked_targets
,
targets
)
official/vision/keras_cv/ops/target_gather_test.py
0 → 100644
View file @
a26d77c4
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for target_gather.py."""
import
tensorflow
as
tf
from
official.vision.keras_cv.ops
import
target_gather
class
TargetGatherTest
(
tf
.
test
.
TestCase
):
def
test_target_gather_batched
(
self
):
gt_boxes
=
tf
.
constant
(
[[
[
0
,
0
,
5
,
5
],
[
0
,
5
,
5
,
10
],
[
5
,
0
,
10
,
5
],
[
5
,
5
,
10
,
10
],
]],
dtype
=
tf
.
float32
)
gt_classes
=
tf
.
constant
([[[
2
],
[
10
],
[
3
],
[
-
1
]]],
dtype
=
tf
.
int32
)
labeler
=
target_gather
.
TargetGather
()
match_indices
=
tf
.
constant
([[
0
,
2
]],
dtype
=
tf
.
int32
)
match_indicators
=
tf
.
constant
([[
-
2
,
1
]])
mask
=
tf
.
less_equal
(
match_indicators
,
0
)
cls_mask
=
tf
.
expand_dims
(
mask
,
-
1
)
matched_gt_classes
=
labeler
(
gt_classes
,
match_indices
,
cls_mask
)
box_mask
=
tf
.
tile
(
cls_mask
,
[
1
,
1
,
4
])
matched_gt_boxes
=
labeler
(
gt_boxes
,
match_indices
,
box_mask
)
self
.
assertAllEqual
(
matched_gt_classes
.
numpy
(),
[[[
0
],
[
3
]]])
self
.
assertAllClose
(
matched_gt_boxes
.
numpy
(),
[[[
0
,
0
,
0
,
0
],
[
5
,
0
,
10
,
5
]]])
def
test_target_gather_unbatched
(
self
):
gt_boxes
=
tf
.
constant
(
[
[
0
,
0
,
5
,
5
],
[
0
,
5
,
5
,
10
],
[
5
,
0
,
10
,
5
],
[
5
,
5
,
10
,
10
],
],
dtype
=
tf
.
float32
)
gt_classes
=
tf
.
constant
([[
2
],
[
10
],
[
3
],
[
-
1
]],
dtype
=
tf
.
int32
)
labeler
=
target_gather
.
TargetGather
()
match_indices
=
tf
.
constant
([
0
,
2
],
dtype
=
tf
.
int32
)
match_indicators
=
tf
.
constant
([
-
2
,
1
])
mask
=
tf
.
less_equal
(
match_indicators
,
0
)
cls_mask
=
tf
.
expand_dims
(
mask
,
-
1
)
matched_gt_classes
=
labeler
(
gt_classes
,
match_indices
,
cls_mask
)
box_mask
=
tf
.
tile
(
cls_mask
,
[
1
,
4
])
matched_gt_boxes
=
labeler
(
gt_boxes
,
match_indices
,
box_mask
)
self
.
assertAllEqual
(
matched_gt_classes
.
numpy
(),
[[
0
],
[
3
]])
self
.
assertAllClose
(
matched_gt_boxes
.
numpy
(),
[[
0
,
0
,
0
,
0
],
[
5
,
0
,
10
,
5
]])
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment