Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
fd67924c
Commit
fd67924c
authored
May 10, 2021
by
Xianzhi Du
Committed by
A. Unique TensorFlower
May 10, 2021
Browse files
Internal change
PiperOrigin-RevId: 373004044
parent
6f50b49a
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
95 additions
and
20 deletions
+95
-20
official/vision/beta/dataloaders/retinanet_input.py
official/vision/beta/dataloaders/retinanet_input.py
+25
-4
official/vision/beta/dataloaders/utils.py
official/vision/beta/dataloaders/utils.py
+4
-0
official/vision/beta/dataloaders/utils_test.py
official/vision/beta/dataloaders/utils_test.py
+13
-5
official/vision/beta/modeling/retinanet_model.py
official/vision/beta/modeling/retinanet_model.py
+2
-2
official/vision/beta/modeling/retinanet_model_test.py
official/vision/beta/modeling/retinanet_model_test.py
+1
-1
official/vision/beta/ops/anchor.py
official/vision/beta/ops/anchor.py
+26
-2
official/vision/beta/ops/anchor_test.py
official/vision/beta/ops/anchor_test.py
+24
-6
No files found.
official/vision/beta/dataloaders/retinanet_input.py
View file @
fd67924c
...
@@ -119,7 +119,12 @@ class Parser(parser.Parser):
...
@@ -119,7 +119,12 @@ class Parser(parser.Parser):
"""Parses data for training and evaluation."""
"""Parses data for training and evaluation."""
classes
=
data
[
'groundtruth_classes'
]
classes
=
data
[
'groundtruth_classes'
]
boxes
=
data
[
'groundtruth_boxes'
]
boxes
=
data
[
'groundtruth_boxes'
]
# If not empty, `attributes` is a dict of (name, ground_truth) pairs.
# `ground_truth` of attributes is assumed to have shape [N, attribute_size].
# TODO(xianzhi): support parsing attribute weights.
attributes
=
data
.
get
(
'groundtruth_attributes'
,
{})
is_crowds
=
data
[
'groundtruth_is_crowd'
]
is_crowds
=
data
[
'groundtruth_is_crowd'
]
# Skips annotations with `is_crowd` = True.
# Skips annotations with `is_crowd` = True.
if
self
.
_skip_crowd_during_training
:
if
self
.
_skip_crowd_during_training
:
num_groundtrtuhs
=
tf
.
shape
(
input
=
classes
)[
0
]
num_groundtrtuhs
=
tf
.
shape
(
input
=
classes
)[
0
]
...
@@ -130,6 +135,8 @@ class Parser(parser.Parser):
...
@@ -130,6 +135,8 @@ class Parser(parser.Parser):
false_fn
=
lambda
:
tf
.
cast
(
tf
.
range
(
num_groundtrtuhs
),
tf
.
int64
))
false_fn
=
lambda
:
tf
.
cast
(
tf
.
range
(
num_groundtrtuhs
),
tf
.
int64
))
classes
=
tf
.
gather
(
classes
,
indices
)
classes
=
tf
.
gather
(
classes
,
indices
)
boxes
=
tf
.
gather
(
boxes
,
indices
)
boxes
=
tf
.
gather
(
boxes
,
indices
)
for
k
,
v
in
attributes
.
items
():
attributes
[
k
]
=
tf
.
gather
(
v
,
indices
)
# Gets original image and its size.
# Gets original image and its size.
image
=
data
[
'image'
]
image
=
data
[
'image'
]
...
@@ -165,6 +172,8 @@ class Parser(parser.Parser):
...
@@ -165,6 +172,8 @@ class Parser(parser.Parser):
indices
=
box_ops
.
get_non_empty_box_indices
(
boxes
)
indices
=
box_ops
.
get_non_empty_box_indices
(
boxes
)
boxes
=
tf
.
gather
(
boxes
,
indices
)
boxes
=
tf
.
gather
(
boxes
,
indices
)
classes
=
tf
.
gather
(
classes
,
indices
)
classes
=
tf
.
gather
(
classes
,
indices
)
for
k
,
v
in
attributes
.
items
():
attributes
[
k
]
=
tf
.
gather
(
v
,
indices
)
# Assigns anchors.
# Assigns anchors.
input_anchor
=
anchor
.
build_anchor_generator
(
input_anchor
=
anchor
.
build_anchor_generator
(
...
@@ -176,9 +185,9 @@ class Parser(parser.Parser):
...
@@ -176,9 +185,9 @@ class Parser(parser.Parser):
anchor_boxes
=
input_anchor
(
image_size
=
(
image_height
,
image_width
))
anchor_boxes
=
input_anchor
(
image_size
=
(
image_height
,
image_width
))
anchor_labeler
=
anchor
.
AnchorLabeler
(
self
.
_match_threshold
,
anchor_labeler
=
anchor
.
AnchorLabeler
(
self
.
_match_threshold
,
self
.
_unmatched_threshold
)
self
.
_unmatched_threshold
)
(
cls_targets
,
box_targets
,
cls_weights
,
(
cls_targets
,
box_targets
,
att_targets
,
cls_weights
,
box_weights
)
=
anchor_labeler
.
label_anchors
(
box_weights
)
=
anchor_labeler
.
label_anchors
(
anchor_boxes
,
boxes
,
tf
.
expand_dims
(
classes
,
axis
=
1
))
anchor_boxes
,
boxes
,
tf
.
expand_dims
(
classes
,
axis
=
1
)
,
attributes
)
# Casts input image to desired data type.
# Casts input image to desired data type.
image
=
tf
.
cast
(
image
,
dtype
=
self
.
_dtype
)
image
=
tf
.
cast
(
image
,
dtype
=
self
.
_dtype
)
...
@@ -192,6 +201,8 @@ class Parser(parser.Parser):
...
@@ -192,6 +201,8 @@ class Parser(parser.Parser):
'box_weights'
:
box_weights
,
'box_weights'
:
box_weights
,
'image_info'
:
image_info
,
'image_info'
:
image_info
,
}
}
if
att_targets
:
labels
[
'attribute_targets'
]
=
att_targets
return
image
,
labels
return
image
,
labels
def
_parse_eval_data
(
self
,
data
):
def
_parse_eval_data
(
self
,
data
):
...
@@ -199,6 +210,10 @@ class Parser(parser.Parser):
...
@@ -199,6 +210,10 @@ class Parser(parser.Parser):
groundtruths
=
{}
groundtruths
=
{}
classes
=
data
[
'groundtruth_classes'
]
classes
=
data
[
'groundtruth_classes'
]
boxes
=
data
[
'groundtruth_boxes'
]
boxes
=
data
[
'groundtruth_boxes'
]
# If not empty, `attributes` is a dict of (name, ground_truth) pairs.
# `ground_truth` of attributes is assumed to have shape [N, attribute_size].
# TODO(xianzhi): support parsing attribute weights.
attributes
=
data
.
get
(
'groundtruth_attributes'
,
{})
# Gets original image and its size.
# Gets original image and its size.
image
=
data
[
'image'
]
image
=
data
[
'image'
]
...
@@ -229,6 +244,8 @@ class Parser(parser.Parser):
...
@@ -229,6 +244,8 @@ class Parser(parser.Parser):
indices
=
box_ops
.
get_non_empty_box_indices
(
boxes
)
indices
=
box_ops
.
get_non_empty_box_indices
(
boxes
)
boxes
=
tf
.
gather
(
boxes
,
indices
)
boxes
=
tf
.
gather
(
boxes
,
indices
)
classes
=
tf
.
gather
(
classes
,
indices
)
classes
=
tf
.
gather
(
classes
,
indices
)
for
k
,
v
in
attributes
.
items
():
attributes
[
k
]
=
tf
.
gather
(
v
,
indices
)
# Assigns anchors.
# Assigns anchors.
input_anchor
=
anchor
.
build_anchor_generator
(
input_anchor
=
anchor
.
build_anchor_generator
(
...
@@ -240,9 +257,9 @@ class Parser(parser.Parser):
...
@@ -240,9 +257,9 @@ class Parser(parser.Parser):
anchor_boxes
=
input_anchor
(
image_size
=
(
image_height
,
image_width
))
anchor_boxes
=
input_anchor
(
image_size
=
(
image_height
,
image_width
))
anchor_labeler
=
anchor
.
AnchorLabeler
(
self
.
_match_threshold
,
anchor_labeler
=
anchor
.
AnchorLabeler
(
self
.
_match_threshold
,
self
.
_unmatched_threshold
)
self
.
_unmatched_threshold
)
(
cls_targets
,
box_targets
,
cls_weights
,
(
cls_targets
,
box_targets
,
att_targets
,
cls_weights
,
box_weights
)
=
anchor_labeler
.
label_anchors
(
box_weights
)
=
anchor_labeler
.
label_anchors
(
anchor_boxes
,
boxes
,
tf
.
expand_dims
(
classes
,
axis
=
1
))
anchor_boxes
,
boxes
,
tf
.
expand_dims
(
classes
,
axis
=
1
)
,
attributes
)
# Casts input image to desired data type.
# Casts input image to desired data type.
image
=
tf
.
cast
(
image
,
dtype
=
self
.
_dtype
)
image
=
tf
.
cast
(
image
,
dtype
=
self
.
_dtype
)
...
@@ -260,6 +277,8 @@ class Parser(parser.Parser):
...
@@ -260,6 +277,8 @@ class Parser(parser.Parser):
'areas'
:
data
[
'groundtruth_area'
],
'areas'
:
data
[
'groundtruth_area'
],
'is_crowds'
:
tf
.
cast
(
data
[
'groundtruth_is_crowd'
],
tf
.
int32
),
'is_crowds'
:
tf
.
cast
(
data
[
'groundtruth_is_crowd'
],
tf
.
int32
),
}
}
if
'groundtruth_attributes'
in
data
:
groundtruths
[
'attributes'
]
=
data
[
'groundtruth_attributes'
]
groundtruths
[
'source_id'
]
=
utils
.
process_source_id
(
groundtruths
[
'source_id'
]
=
utils
.
process_source_id
(
groundtruths
[
'source_id'
])
groundtruths
[
'source_id'
])
groundtruths
=
utils
.
pad_groundtruths_to_fixed_size
(
groundtruths
=
utils
.
pad_groundtruths_to_fixed_size
(
...
@@ -275,4 +294,6 @@ class Parser(parser.Parser):
...
@@ -275,4 +294,6 @@ class Parser(parser.Parser):
'image_info'
:
image_info
,
'image_info'
:
image_info
,
'groundtruths'
:
groundtruths
,
'groundtruths'
:
groundtruths
,
}
}
if
att_targets
:
labels
[
'attribute_targets'
]
=
att_targets
return
image
,
labels
return
image
,
labels
official/vision/beta/dataloaders/utils.py
View file @
fd67924c
...
@@ -62,4 +62,8 @@ def pad_groundtruths_to_fixed_size(groundtruths: Dict[str, tf.Tensor],
...
@@ -62,4 +62,8 @@ def pad_groundtruths_to_fixed_size(groundtruths: Dict[str, tf.Tensor],
groundtruths
[
'areas'
],
size
,
-
1
)
groundtruths
[
'areas'
],
size
,
-
1
)
groundtruths
[
'classes'
]
=
preprocess_ops
.
clip_or_pad_to_fixed_size
(
groundtruths
[
'classes'
]
=
preprocess_ops
.
clip_or_pad_to_fixed_size
(
groundtruths
[
'classes'
],
size
,
-
1
)
groundtruths
[
'classes'
],
size
,
-
1
)
if
'attributes'
in
groundtruths
:
for
k
,
v
in
groundtruths
[
'attributes'
].
items
():
groundtruths
[
'attributes'
][
k
]
=
preprocess_ops
.
clip_or_pad_to_fixed_size
(
v
,
size
,
-
1
)
return
groundtruths
return
groundtruths
official/vision/beta/dataloaders/utils_test.py
View file @
fd67924c
...
@@ -40,23 +40,31 @@ class UtilsTest(tf.test.TestCase, parameterized.TestCase):
...
@@ -40,23 +40,31 @@ class UtilsTest(tf.test.TestCase, parameterized.TestCase):
utils
.
process_source_id
(
source_id
=
source_id
))
utils
.
process_source_id
(
source_id
=
source_id
))
@
parameterized
.
parameters
(
@
parameterized
.
parameters
(
([[
10
,
20
,
30
,
40
]],
[[
100
]],
[[
0
]],
10
),
([[
10
,
20
,
30
,
40
]],
[[
100
]],
[[
0
]],
10
,
None
),
([[
0.1
,
0.2
,
0.5
,
0.6
]],
[[
0.5
]],
[[
1
]],
2
),
([[
0.1
,
0.2
,
0.5
,
0.6
]],
[[
0.5
]],
[[
1
]],
2
,
[[
1.0
,
2.0
]]
),
)
)
def
test_pad_groundtruths_to_fixed_size
(
self
,
boxes
,
area
,
classes
,
size
):
def
test_pad_groundtruths_to_fixed_size
(
self
,
boxes
,
area
,
classes
,
size
,
attributes
):
groundtruths
=
{}
groundtruths
=
{}
groundtruths
[
'boxes'
]
=
tf
.
constant
(
boxes
)
groundtruths
[
'boxes'
]
=
tf
.
constant
(
boxes
)
groundtruths
[
'is_crowds'
]
=
tf
.
constant
([[
0
]])
groundtruths
[
'is_crowds'
]
=
tf
.
constant
([[
0
]])
groundtruths
[
'areas'
]
=
tf
.
constant
(
area
)
groundtruths
[
'areas'
]
=
tf
.
constant
(
area
)
groundtruths
[
'classes'
]
=
tf
.
constant
(
classes
)
groundtruths
[
'classes'
]
=
tf
.
constant
(
classes
)
if
attributes
:
groundtruths
[
'attributes'
]
=
{
'depth'
:
tf
.
constant
(
attributes
)}
actual_result
=
utils
.
pad_groundtruths_to_fixed_size
(
actual_result
=
utils
.
pad_groundtruths_to_fixed_size
(
groundtruths
=
groundtruths
,
size
=
size
)
groundtruths
=
groundtruths
,
size
=
size
)
# Check that the first dimension is padded to the expected size.
# Check that the first dimension is padded to the expected size.
for
key
in
actual_result
:
for
key
in
actual_result
:
pad_shape
=
actual_result
[
key
].
shape
[
0
]
if
key
==
'attributes'
:
self
.
assertEqual
(
size
,
pad_shape
)
for
_
,
v
in
actual_result
[
key
].
items
():
pad_shape
=
v
.
shape
[
0
]
self
.
assertEqual
(
size
,
pad_shape
)
else
:
pad_shape
=
actual_result
[
key
].
shape
[
0
]
self
.
assertEqual
(
size
,
pad_shape
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
...
...
official/vision/beta/modeling/retinanet_model.py
View file @
fd67924c
...
@@ -126,7 +126,7 @@ class RetinaNetModel(tf.keras.Model):
...
@@ -126,7 +126,7 @@ class RetinaNetModel(tf.keras.Model):
'box_outputs'
:
raw_boxes
,
'box_outputs'
:
raw_boxes
,
}
}
if
raw_attributes
:
if
raw_attributes
:
outputs
.
update
({
'att_outputs'
:
raw_attributes
})
outputs
.
update
({
'att
ribute
_outputs'
:
raw_attributes
})
return
outputs
return
outputs
else
:
else
:
# Generate anchor boxes for this batch if not provided.
# Generate anchor boxes for this batch if not provided.
...
@@ -166,7 +166,7 @@ class RetinaNetModel(tf.keras.Model):
...
@@ -166,7 +166,7 @@ class RetinaNetModel(tf.keras.Model):
if
raw_attributes
:
if
raw_attributes
:
outputs
.
update
({
outputs
.
update
({
'att_outputs'
:
raw_attributes
,
'att
ribute
_outputs'
:
raw_attributes
,
'detection_attributes'
:
final_results
[
'detection_attributes'
],
'detection_attributes'
:
final_results
[
'detection_attributes'
],
})
})
return
outputs
return
outputs
...
...
official/vision/beta/modeling/retinanet_model_test.py
View file @
fd67924c
...
@@ -223,7 +223,7 @@ class RetinaNetTest(parameterized.TestCase, tf.test.TestCase):
...
@@ -223,7 +223,7 @@ class RetinaNetTest(parameterized.TestCase, tf.test.TestCase):
4
*
num_anchors_per_location
4
*
num_anchors_per_location
],
box_outputs
[
str
(
level
)].
numpy
().
shape
)
],
box_outputs
[
str
(
level
)].
numpy
().
shape
)
if
has_att_heads
:
if
has_att_heads
:
att_outputs
=
model_outputs
[
'att_outputs'
]
att_outputs
=
model_outputs
[
'att
ribute
_outputs'
]
for
att
in
att_outputs
.
values
():
for
att
in
att_outputs
.
values
():
self
.
assertAllEqual
([
self
.
assertAllEqual
([
2
,
image_size
[
0
]
//
2
**
level
,
image_size
[
1
]
//
2
**
level
,
2
,
image_size
[
0
]
//
2
**
level
,
image_size
[
1
]
//
2
**
level
,
...
...
official/vision/beta/ops/anchor.py
View file @
fd67924c
...
@@ -140,7 +140,11 @@ class AnchorLabeler(object):
...
@@ -140,7 +140,11 @@ class AnchorLabeler(object):
force_match_for_each_col
=
True
)
force_match_for_each_col
=
True
)
self
.
box_coder
=
faster_rcnn_box_coder
.
FasterRcnnBoxCoder
()
self
.
box_coder
=
faster_rcnn_box_coder
.
FasterRcnnBoxCoder
()
def
label_anchors
(
self
,
anchor_boxes
,
gt_boxes
,
gt_labels
):
def
label_anchors
(
self
,
anchor_boxes
,
gt_boxes
,
gt_labels
,
gt_attributes
=
None
):
"""Labels anchors with ground truth inputs.
"""Labels anchors with ground truth inputs.
Args:
Args:
...
@@ -150,6 +154,9 @@ class AnchorLabeler(object):
...
@@ -150,6 +154,9 @@ class AnchorLabeler(object):
For each row, it stores [y0, x0, y1, x1] for four corners of a box.
For each row, it stores [y0, x0, y1, x1] for four corners of a box.
gt_labels: An integer tensor with shape [N, 1] representing groundtruth
gt_labels: An integer tensor with shape [N, 1] representing groundtruth
classes.
classes.
gt_attributes: If not None, a dict of (name, gt_attribute) pairs.
`gt_attribute` is a float tensor with shape [N, attribute_size]
representing groundtruth attributes.
Returns:
Returns:
cls_targets_dict: ordered dictionary with keys
cls_targets_dict: ordered dictionary with keys
[min_level, min_level+1, ..., max_level]. The values are tensors with
[min_level, min_level+1, ..., max_level]. The values are tensors with
...
@@ -160,6 +167,12 @@ class AnchorLabeler(object):
...
@@ -160,6 +167,12 @@ class AnchorLabeler(object):
shape [height_l, width_l, num_anchors_per_location * 4]. The height_l
shape [height_l, width_l, num_anchors_per_location * 4]. The height_l
and width_l represent the dimension of bounding box regression output at
and width_l represent the dimension of bounding box regression output at
l-th level.
l-th level.
attribute_targets_dict: a dict with (name, attribute_targets) pairs. Each
`attribute_targets` represents an ordered dictionary with keys
[min_level, min_level+1, ..., max_level]. The values are tensors with
shape [height_l, width_l, num_anchors_per_location * attribute_size].
The height_l and width_l represent the dimension of attribute prediction
output at l-th level.
cls_weights: A flattened Tensor with shape [batch_size, num_anchors], that
cls_weights: A flattened Tensor with shape [batch_size, num_anchors], that
serves as masking / sample weight for classification loss. Its value
serves as masking / sample weight for classification loss. Its value
is 1.0 for positive and negative matched anchors, and 0.0 for ignored
is 1.0 for positive and negative matched anchors, and 0.0 for ignored
...
@@ -175,11 +188,19 @@ class AnchorLabeler(object):
...
@@ -175,11 +188,19 @@ class AnchorLabeler(object):
flattened_anchor_boxes
=
tf
.
concat
(
flattened_anchor_boxes
,
axis
=
0
)
flattened_anchor_boxes
=
tf
.
concat
(
flattened_anchor_boxes
,
axis
=
0
)
similarity_matrix
=
self
.
similarity_calc
(
flattened_anchor_boxes
,
gt_boxes
)
similarity_matrix
=
self
.
similarity_calc
(
flattened_anchor_boxes
,
gt_boxes
)
match_indices
,
match_indicators
=
self
.
matcher
(
similarity_matrix
)
match_indices
,
match_indicators
=
self
.
matcher
(
similarity_matrix
)
mask
=
tf
.
less_equal
(
match_indicators
,
0
)
mask
=
tf
.
less_equal
(
match_indicators
,
0
)
cls_mask
=
tf
.
expand_dims
(
mask
,
-
1
)
cls_mask
=
tf
.
expand_dims
(
mask
,
-
1
)
cls_targets
=
self
.
target_gather
(
gt_labels
,
match_indices
,
cls_mask
,
-
1
)
cls_targets
=
self
.
target_gather
(
gt_labels
,
match_indices
,
cls_mask
,
-
1
)
box_mask
=
tf
.
tile
(
cls_mask
,
[
1
,
4
])
box_mask
=
tf
.
tile
(
cls_mask
,
[
1
,
4
])
box_targets
=
self
.
target_gather
(
gt_boxes
,
match_indices
,
box_mask
)
box_targets
=
self
.
target_gather
(
gt_boxes
,
match_indices
,
box_mask
)
att_targets
=
{}
if
gt_attributes
:
for
k
,
v
in
gt_attributes
.
items
():
att_size
=
v
.
get_shape
().
as_list
()[
-
1
]
att_mask
=
tf
.
tile
(
cls_mask
,
[
1
,
att_size
])
att_targets
[
k
]
=
self
.
target_gather
(
v
,
match_indices
,
att_mask
,
-
1
)
weights
=
tf
.
squeeze
(
tf
.
ones_like
(
gt_labels
,
dtype
=
tf
.
float32
),
-
1
)
weights
=
tf
.
squeeze
(
tf
.
ones_like
(
gt_labels
,
dtype
=
tf
.
float32
),
-
1
)
box_weights
=
self
.
target_gather
(
weights
,
match_indices
,
mask
)
box_weights
=
self
.
target_gather
(
weights
,
match_indices
,
mask
)
ignore_mask
=
tf
.
equal
(
match_indicators
,
-
2
)
ignore_mask
=
tf
.
equal
(
match_indicators
,
-
2
)
...
@@ -191,8 +212,11 @@ class AnchorLabeler(object):
...
@@ -191,8 +212,11 @@ class AnchorLabeler(object):
# Unpacks labels into multi-level representations.
# Unpacks labels into multi-level representations.
cls_targets_dict
=
unpack_targets
(
cls_targets
,
anchor_boxes
)
cls_targets_dict
=
unpack_targets
(
cls_targets
,
anchor_boxes
)
box_targets_dict
=
unpack_targets
(
box_targets
,
anchor_boxes
)
box_targets_dict
=
unpack_targets
(
box_targets
,
anchor_boxes
)
attribute_targets_dict
=
{}
for
k
,
v
in
att_targets
.
items
():
attribute_targets_dict
[
k
]
=
unpack_targets
(
v
,
anchor_boxes
)
return
cls_targets_dict
,
box_targets_dict
,
cls_weights
,
box_weights
return
cls_targets_dict
,
box_targets_dict
,
attribute_targets_dict
,
cls_weights
,
box_weights
class
RpnAnchorLabeler
(
AnchorLabeler
):
class
RpnAnchorLabeler
(
AnchorLabeler
):
...
...
official/vision/beta/ops/anchor_test.py
View file @
fd67924c
...
@@ -107,12 +107,15 @@ class AnchorTest(parameterized.TestCase, tf.test.TestCase):
...
@@ -107,12 +107,15 @@ class AnchorTest(parameterized.TestCase, tf.test.TestCase):
self
.
assertEqual
(
expected_boxes
,
boxes
.
tolist
())
self
.
assertEqual
(
expected_boxes
,
boxes
.
tolist
())
@
parameterized
.
parameters
(
@
parameterized
.
parameters
(
(
3
,
6
,
2
,
[
1.0
],
2.0
),
(
3
,
6
,
2
,
[
1.0
],
2.0
,
False
),
(
3
,
6
,
2
,
[
1.0
],
2.0
,
True
),
)
)
def
testLabelAnchors
(
self
,
min_level
,
max_level
,
num_scales
,
def
testLabelAnchors
(
self
,
min_level
,
max_level
,
num_scales
,
aspect_ratios
,
a
spect_ratios
,
anchor_siz
e
):
a
nchor_size
,
has_attribut
e
):
input_size
=
[
512
,
512
]
input_size
=
[
512
,
512
]
ground_truth_class_id
=
2
ground_truth_class_id
=
2
attribute_name
=
'depth'
ground_truth_depth
=
3.0
# The matched anchors are the anchors used as ground truth and the anchors
# The matched anchors are the anchors used as ground truth and the anchors
# at the next octave scale on the same location.
# at the next octave scale on the same location.
...
@@ -126,9 +129,13 @@ class AnchorTest(parameterized.TestCase, tf.test.TestCase):
...
@@ -126,9 +129,13 @@ class AnchorTest(parameterized.TestCase, tf.test.TestCase):
# two anchors with two intermediate scales at the same location.
# two anchors with two intermediate scales at the same location.
gt_boxes
=
anchor_boxes
[
'3'
][
0
:
1
,
0
,
0
:
4
]
gt_boxes
=
anchor_boxes
[
'3'
][
0
:
1
,
0
,
0
:
4
]
gt_classes
=
tf
.
constant
([[
ground_truth_class_id
]],
dtype
=
tf
.
float32
)
gt_classes
=
tf
.
constant
([[
ground_truth_class_id
]],
dtype
=
tf
.
float32
)
(
cls_targets
,
box_targets
,
_
,
gt_attributes
=
{
box_weights
)
=
anchor_labeler
.
label_anchors
(
attribute_name
:
tf
.
constant
([[
ground_truth_depth
]],
dtype
=
tf
.
float32
)
anchor_boxes
,
gt_boxes
,
gt_classes
)
}
if
has_attribute
else
{}
(
cls_targets
,
box_targets
,
att_targets
,
_
,
box_weights
)
=
anchor_labeler
.
label_anchors
(
anchor_boxes
,
gt_boxes
,
gt_classes
,
gt_attributes
)
for
k
,
v
in
cls_targets
.
items
():
for
k
,
v
in
cls_targets
.
items
():
cls_targets
[
k
]
=
v
.
numpy
()
cls_targets
[
k
]
=
v
.
numpy
()
...
@@ -142,6 +149,17 @@ class AnchorTest(parameterized.TestCase, tf.test.TestCase):
...
@@ -142,6 +149,17 @@ class AnchorTest(parameterized.TestCase, tf.test.TestCase):
# Two anchor boxes on min_level got matched to the gt_boxes.
# Two anchor boxes on min_level got matched to the gt_boxes.
self
.
assertAllClose
(
tf
.
reduce_sum
(
box_weights
),
2
)
self
.
assertAllClose
(
tf
.
reduce_sum
(
box_weights
),
2
)
if
has_attribute
:
self
.
assertIn
(
attribute_name
,
att_targets
)
for
k
,
v
in
att_targets
[
attribute_name
].
items
():
att_targets
[
attribute_name
][
k
]
=
v
.
numpy
()
anchor_locations
=
np
.
vstack
(
np
.
where
(
att_targets
[
attribute_name
][
str
(
min_level
)]
>
-
1
)).
transpose
()
self
.
assertAllClose
(
expected_anchor_locations
,
anchor_locations
)
else
:
self
.
assertEmpty
(
att_targets
)
@
parameterized
.
parameters
(
@
parameterized
.
parameters
(
(
3
,
7
,
[.
5
,
1.
,
2.
],
2
,
8
,
(
256
,
256
)),
(
3
,
7
,
[.
5
,
1.
,
2.
],
2
,
8
,
(
256
,
256
)),
(
3
,
8
,
[
1.
],
3
,
32
,
(
512
,
512
)),
(
3
,
8
,
[
1.
],
3
,
32
,
(
512
,
512
)),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment