Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
e1c78a72
Commit
e1c78a72
authored
Oct 02, 2020
by
Zhenyu Tan
Committed by
A. Unique TensorFlower
Oct 02, 2020
Browse files
Internal change
PiperOrigin-RevId: 335079988
parent
44a5367a
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
92 additions
and
87 deletions
+92
-87
official/vision/beta/modeling/layers/roi_sampler.py
official/vision/beta/modeling/layers/roi_sampler.py
+34
-8
official/vision/beta/ops/anchor.py
official/vision/beta/ops/anchor.py
+2
-2
official/vision/keras_cv/ops/box_matcher.py
official/vision/keras_cv/ops/box_matcher.py
+56
-77
No files found.
official/vision/beta/modeling/layers/roi_sampler.py
View file @
e1c78a72
...
@@ -17,7 +17,7 @@
...
@@ -17,7 +17,7 @@
# Import libraries
# Import libraries
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.vision
.beta.modeling.layers
import
box_matcher
from
official.vision
import
keras_cv
from
official.vision.beta.modeling.layers
import
box_sampler
from
official.vision.beta.modeling.layers
import
box_sampler
from
official.vision.beta.ops
import
box_ops
from
official.vision.beta.ops
import
box_ops
...
@@ -60,10 +60,15 @@ class ROISampler(tf.keras.layers.Layer):
...
@@ -60,10 +60,15 @@ class ROISampler(tf.keras.layers.Layer):
'background_iou_high_threshold'
:
background_iou_high_threshold
,
'background_iou_high_threshold'
:
background_iou_high_threshold
,
'background_iou_low_threshold'
:
background_iou_low_threshold
,
'background_iou_low_threshold'
:
background_iou_low_threshold
,
}
}
self
.
_matcher
=
box_matcher
.
BoxMatcher
(
foreground_iou_threshold
,
self
.
_box_matcher
=
keras_cv
.
ops
.
BoxMatcher
(
background_iou_high_threshold
,
thresholds
=
[
background_iou_low_threshold
)
background_iou_low_threshold
,
background_iou_high_threshold
,
foreground_iou_threshold
],
indicators
=
[
-
3
,
-
1
,
-
2
,
1
])
self
.
_anchor_labeler
=
keras_cv
.
ops
.
AnchorLabeler
()
self
.
_sampler
=
box_sampler
.
BoxSampler
(
self
.
_sampler
=
box_sampler
.
BoxSampler
(
num_sampled_rois
,
foreground_fraction
)
num_sampled_rois
,
foreground_fraction
)
super
(
ROISampler
,
self
).
__init__
(
**
kwargs
)
super
(
ROISampler
,
self
).
__init__
(
**
kwargs
)
...
@@ -109,9 +114,30 @@ class ROISampler(tf.keras.layers.Layer):
...
@@ -109,9 +114,30 @@ class ROISampler(tf.keras.layers.Layer):
gt_boxes
=
tf
.
cast
(
gt_boxes
,
dtype
=
boxes
.
dtype
)
gt_boxes
=
tf
.
cast
(
gt_boxes
,
dtype
=
boxes
.
dtype
)
boxes
=
tf
.
concat
([
boxes
,
gt_boxes
],
axis
=
1
)
boxes
=
tf
.
concat
([
boxes
,
gt_boxes
],
axis
=
1
)
(
matched_gt_boxes
,
matched_gt_classes
,
matched_gt_indices
,
similarity_matrix
=
box_ops
.
bbox_overlap
(
boxes
,
gt_boxes
)
positive_matches
,
negative_matches
,
ignored_matches
)
=
(
matched_gt_indices
,
match_indicators
=
self
.
_box_matcher
(
similarity_matrix
)
self
.
_matcher
(
boxes
,
gt_boxes
,
gt_classes
))
positive_matches
=
tf
.
greater_equal
(
match_indicators
,
0
)
negative_matches
=
tf
.
equal
(
match_indicators
,
-
1
)
ignored_matches
=
tf
.
equal
(
match_indicators
,
-
2
)
invalid_matches
=
tf
.
equal
(
match_indicators
,
-
3
)
background_mask
=
tf
.
expand_dims
(
tf
.
logical_or
(
negative_matches
,
invalid_matches
),
-
1
)
gt_classes
=
tf
.
expand_dims
(
gt_classes
,
axis
=-
1
)
matched_gt_classes
=
self
.
_anchor_labeler
(
gt_classes
,
matched_gt_indices
,
background_mask
)
matched_gt_classes
=
tf
.
where
(
background_mask
,
tf
.
zeros_like
(
matched_gt_classes
),
matched_gt_classes
)
matched_gt_classes
=
tf
.
squeeze
(
matched_gt_classes
,
axis
=-
1
)
matched_gt_boxes
=
self
.
_anchor_labeler
(
gt_boxes
,
matched_gt_indices
,
tf
.
tile
(
background_mask
,
[
1
,
1
,
4
]))
matched_gt_boxes
=
tf
.
where
(
background_mask
,
tf
.
zeros_like
(
matched_gt_boxes
),
matched_gt_boxes
)
matched_gt_indices
=
tf
.
where
(
tf
.
squeeze
(
background_mask
,
-
1
),
-
tf
.
ones_like
(
matched_gt_indices
),
matched_gt_indices
)
sampled_indices
=
self
.
_sampler
(
sampled_indices
=
self
.
_sampler
(
positive_matches
,
negative_matches
,
ignored_matches
)
positive_matches
,
negative_matches
,
ignored_matches
)
...
...
official/vision/beta/ops/anchor.py
View file @
e1c78a72
...
@@ -135,8 +135,8 @@ class AnchorLabeler(object):
...
@@ -135,8 +135,8 @@ class AnchorLabeler(object):
self
.
similarity_calc
=
keras_cv
.
ops
.
IouSimilarity
()
self
.
similarity_calc
=
keras_cv
.
ops
.
IouSimilarity
()
self
.
anchor_labeler
=
keras_cv
.
ops
.
AnchorLabeler
()
self
.
anchor_labeler
=
keras_cv
.
ops
.
AnchorLabeler
()
self
.
matcher
=
keras_cv
.
ops
.
BoxMatcher
(
self
.
matcher
=
keras_cv
.
ops
.
BoxMatcher
(
positive
_threshold
=
match_threshold
,
thresholds
=
[
unmatched
_threshold
,
match_threshold
]
,
negative_threshold
=
unmatched_threshold
,
indicators
=
[
-
1
,
-
2
,
1
]
,
force_match_for_each_col
=
True
)
force_match_for_each_col
=
True
)
self
.
box_coder
=
faster_rcnn_box_coder
.
FasterRcnnBoxCoder
()
self
.
box_coder
=
faster_rcnn_box_coder
.
FasterRcnnBoxCoder
()
...
...
official/vision/keras_cv/ops/box_matcher.py
View file @
e1c78a72
...
@@ -28,60 +28,51 @@ class BoxMatcher:
...
@@ -28,60 +28,51 @@ class BoxMatcher:
To support object detection target assignment this class enables setting both
To support object detection target assignment this class enables setting both
positive_threshold (upper threshold) and negative_threshold (lower thresholds)
positive_threshold (upper threshold) and negative_threshold (lower thresholds)
defining three categories of similarity which define whether examples are
defining three categories of similarity which define whether examples are
positive, negative, or ignored:
positive, negative, or ignored, for example:
(1) similarity >= positive_threshold: Highest similarity. Matched/Positive!
(1) thresholds=[negative_threshold, positive_threshold], and
(2) positive_threshold > similarity >= negative_threshold: Medium similarity.
indicators=[negative_value, ignore_value, positive_value]: The similarity
This is Ignored.
metrics below negative_threshold will be assigned with negative_value,
(3) negative_threshold > similarity: Lowest similarity for Negative Match.
the metrics between negative_threshold and positive_threshold will be
For ignored matches this class sets the values in the Match object to -2.
assigned ignore_value, and the metrics above positive_threshold will be
assigned positive_value.
(2) thresholds=[negative_threshold, positive_threshold], and
indicators=[ignore_value, negative_value, positive_value]: The similarity
metric below negative_threshold will be assigned with ignore_value,
the metrics between negative_threshold and positive_threshold will be
assigned negative_value, and the metrics above positive_threshold will be
assigned positive_value.
"""
"""
def
__init__
(
def
__init__
(
self
,
thresholds
,
indicators
,
force_match_for_each_col
=
False
):
self
,
positive_threshold
,
negative_threshold
=
None
,
force_match_for_each_col
=
False
,
negative_lower_than_ignore
=
True
,
positive_value
=
1
,
negative_value
=-
1
,
ignore_value
=-
2
):
"""Construct BoxMatcher.
"""Construct BoxMatcher.
Args:
Args:
positive_threshold: Threshold for positive matches. Positive if
thresholds: A list of thresholds to classify boxes into
sim >= positive_threshold, where sim is the maximum value of the
different buckets. The list needs to be sorted, and will be prepended
similarity matrix for a given column. Set to None for no threshold.
with -Inf and appended with +Inf.
negative_threshold: Threshold for negative matches. Negative if
indicators: A list of values to assign for each bucket. len(`indicators`)
sim < negative_threshold or
must equal to len(`thresholds`) + 1.
positive_threshold > sim >= negative_threshold.
Defaults to positive_threshold when set to None.
force_match_for_each_col: If True, ensures that each column is matched to
force_match_for_each_col: If True, ensures that each column is matched to
at least one row (which is not guaranteed otherwise if the
at least one row (which is not guaranteed otherwise if the
positive_threshold is high). Defaults to False.
positive_threshold is high). Defaults to False. If True, all force
negative_lower_than_ignore: If True, the threshold is
matched row will be assigned to `indicators[-1]`.
positive|ignore|negative, else positive|negative|ignore. Defaults to
True.
positive_value: An integer to fill for positive match labels.
negative_value: An integer to fill for negative match labels.
ignore_value: An integer to fill for ignored match labels.
Raises:
Raises:
ValueError: If negative_threshold > positive_threshold.
ValueError: If `threshold` not sorted,
or len(indicators) != len(threshold) + 1
"""
"""
self
.
_positive_threshold
=
positive_threshold
if
not
all
([
lo
<=
hi
for
(
lo
,
hi
)
in
zip
(
thresholds
[:
-
1
],
thresholds
[
1
:])]):
if
negative_threshold
is
None
:
raise
ValueError
(
'`threshold` must be sorted, got {}'
.
format
(
thresholds
))
self
.
_negative_threshold
=
positive_threshold
self
.
indicators
=
indicators
else
:
if
len
(
indicators
)
!=
len
(
thresholds
)
+
1
:
if
negative_threshold
>
positive_threshold
:
raise
ValueError
(
'len(`indicators`) must be len(`thresholds`) + 1, got '
raise
ValueError
(
'negative_threshold needs to be smaller or equal'
'indicators {}, thresholds {}'
.
format
(
'to positive_threshold'
)
indicators
,
thresholds
))
self
.
_negative_threshold
=
negative_threshold
thresholds
=
thresholds
[:]
thresholds
.
insert
(
0
,
-
float
(
'inf'
))
self
.
_positive_value
=
positive_value
thresholds
.
append
(
float
(
'inf'
))
self
.
_negative_value
=
negative_value
self
.
thresholds
=
thresholds
self
.
_ignore_value
=
ignore_value
self
.
_force_match_for_each_col
=
force_match_for_each_col
self
.
_force_match_for_each_col
=
force_match_for_each_col
self
.
_negative_lower_than_ignore
=
negative_lower_than_ignore
def
__call__
(
self
,
similarity_matrix
):
def
__call__
(
self
,
similarity_matrix
):
"""Tries to match each column of the similarity matrix to a row.
"""Tries to match each column of the similarity matrix to a row.
...
@@ -117,8 +108,7 @@ class BoxMatcher:
...
@@ -117,8 +108,7 @@ class BoxMatcher:
"""
"""
with
tf
.
name_scope
(
'empty_gt_boxes'
):
with
tf
.
name_scope
(
'empty_gt_boxes'
):
matches
=
tf
.
zeros
([
batch_size
,
num_rows
],
dtype
=
tf
.
int32
)
matches
=
tf
.
zeros
([
batch_size
,
num_rows
],
dtype
=
tf
.
int32
)
match_labels
=
self
.
_negative_value
*
tf
.
ones
(
match_labels
=
-
tf
.
ones
([
batch_size
,
num_rows
],
dtype
=
tf
.
int32
)
[
batch_size
,
num_rows
],
dtype
=
tf
.
int32
)
return
matches
,
match_labels
return
matches
,
match_labels
def
_match_when_rows_are_non_empty
():
def
_match_when_rows_are_non_empty
():
...
@@ -133,28 +123,18 @@ class BoxMatcher:
...
@@ -133,28 +123,18 @@ class BoxMatcher:
# Get logical indices of ignored and unmatched columns as tf.int64
# Get logical indices of ignored and unmatched columns as tf.int64
matched_vals
=
tf
.
reduce_max
(
similarity_matrix
,
axis
=-
1
)
matched_vals
=
tf
.
reduce_max
(
similarity_matrix
,
axis
=-
1
)
matched_labels
=
self
.
_positive_value
*
tf
.
ones
(
matched_indicators
=
tf
.
zeros
([
batch_size
,
num_rows
],
tf
.
int32
)
[
batch_size
,
num_rows
],
tf
.
int32
)
match_dtype
=
matched_vals
.
dtype
positive_threshold
=
tf
.
cast
(
for
(
ind
,
low
,
high
)
in
zip
(
self
.
indicators
,
self
.
thresholds
[:
-
1
],
self
.
_positive_threshold
,
matched_vals
.
dtype
)
self
.
thresholds
[
1
:]):
negative_threshold
=
tf
.
cast
(
low_threshold
=
tf
.
cast
(
low
,
match_dtype
)
self
.
_negative_threshold
,
matched_vals
.
dtype
)
high_threshold
=
tf
.
cast
(
high
,
match_dtype
)
below_negative_threshold
=
tf
.
greater
(
negative_threshold
,
matched_vals
)
mask
=
tf
.
logical_and
(
between_thresholds
=
tf
.
logical_and
(
tf
.
greater_equal
(
matched_vals
,
low_threshold
),
tf
.
greater_equal
(
matched_vals
,
negative_threshold
),
tf
.
less
(
matched_vals
,
high_threshold
))
tf
.
greater
(
positive_threshold
,
matched_vals
))
matched_indicators
=
self
.
_set_values_using_indicator
(
matched_indicators
,
mask
,
ind
)
if
self
.
_negative_lower_than_ignore
:
matched_labels
=
self
.
_set_values_using_indicator
(
matched_labels
,
below_negative_threshold
,
self
.
_negative_value
)
matched_labels
=
self
.
_set_values_using_indicator
(
matched_labels
,
between_thresholds
,
self
.
_ignore_value
)
else
:
matched_labels
=
self
.
_set_values_using_indicator
(
matched_labels
,
below_negative_threshold
,
self
.
_ignore_value
)
matched_labels
=
self
.
_set_values_using_indicator
(
matched_labels
,
between_thresholds
,
self
.
_negative_value
)
if
self
.
_force_match_for_each_col
:
if
self
.
_force_match_for_each_col
:
# [batch_size, M], for each col (groundtruth_box), find the best
# [batch_size, M], for each col (groundtruth_box), find the best
...
@@ -175,27 +155,26 @@ class BoxMatcher:
...
@@ -175,27 +155,26 @@ class BoxMatcher:
# [batch_size, N]
# [batch_size, N]
final_matches
=
tf
.
where
(
force_match_column_mask
,
force_match_row_ids
,
final_matches
=
tf
.
where
(
force_match_column_mask
,
force_match_row_ids
,
matches
)
matches
)
final_matched_labels
=
tf
.
where
(
final_matched_indicators
=
tf
.
where
(
force_match_column_mask
,
force_match_column_mask
,
self
.
indicators
[
-
1
]
*
self
.
_positive_value
*
tf
.
ones
(
tf
.
ones
([
batch_size
,
num_rows
],
dtype
=
tf
.
int32
),
[
batch_size
,
num_rows
],
dtype
=
tf
.
int32
),
matched_indicators
)
matched_labels
)
return
final_matches
,
final_matched_indicators
return
final_matches
,
final_matched_labels
else
:
else
:
return
matches
,
matched_
label
s
return
matches
,
matched_
indicator
s
num_gt_boxes
=
similarity_matrix
.
shape
.
as_list
()[
-
1
]
or
tf
.
shape
(
num_gt_boxes
=
similarity_matrix
.
shape
.
as_list
()[
-
1
]
or
tf
.
shape
(
similarity_matrix
)[
-
1
]
similarity_matrix
)[
-
1
]
result_match
,
result_match
_label
s
=
tf
.
cond
(
result_match
,
result_match
ed_indicator
s
=
tf
.
cond
(
pred
=
tf
.
greater
(
num_gt_boxes
,
0
),
pred
=
tf
.
greater
(
num_gt_boxes
,
0
),
true_fn
=
_match_when_rows_are_non_empty
,
true_fn
=
_match_when_rows_are_non_empty
,
false_fn
=
_match_when_rows_are_empty
)
false_fn
=
_match_when_rows_are_empty
)
if
squeeze_result
:
if
squeeze_result
:
result_match
=
tf
.
squeeze
(
result_match
,
axis
=
0
)
result_match
=
tf
.
squeeze
(
result_match
,
axis
=
0
)
result_match
_label
s
=
tf
.
squeeze
(
result_match
_label
s
,
axis
=
0
)
result_match
ed_indicator
s
=
tf
.
squeeze
(
result_match
ed_indicator
s
,
axis
=
0
)
return
result_match
,
result_match
_label
s
return
result_match
,
result_match
ed_indicator
s
def
_set_values_using_indicator
(
self
,
x
,
indicator
,
val
):
def
_set_values_using_indicator
(
self
,
x
,
indicator
,
val
):
"""Set the indicated fields of x to val.
"""Set the indicated fields of x to val.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment