Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
4bf492a8
Commit
4bf492a8
authored
Jun 04, 2021
by
Ronny Votel
Committed by
TF Object Detection Team
Jun 04, 2021
Browse files
Updating the centernet mask target assigner.
PiperOrigin-RevId: 377511299
parent
33d1ce83
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
88 additions
and
15 deletions
+88
-15
research/object_detection/core/target_assigner.py
research/object_detection/core/target_assigner.py
+43
-5
research/object_detection/core/target_assigner_test.py
research/object_detection/core/target_assigner_test.py
+30
-7
research/object_detection/meta_architectures/center_net_meta_arch.py
...ject_detection/meta_architectures/center_net_meta_arch.py
+15
-3
No files found.
research/object_detection/core/target_assigner.py
View file @
4bf492a8
...
@@ -2001,8 +2001,8 @@ class CenterNetMaskTargetAssigner(object):
...
@@ -2001,8 +2001,8 @@ class CenterNetMaskTargetAssigner(object):
self
.
_stride
=
stride
self
.
_stride
=
stride
def
assign_segmentation_targets
(
def
assign_segmentation_targets
(
self
,
gt_masks_list
,
gt_classes_list
,
self
,
gt_masks_list
,
gt_classes_list
,
gt_boxes_list
=
None
,
mask_resize_method
=
ResizeMethod
.
BILINEAR
):
gt_mask_weights_list
=
None
,
mask_resize_method
=
ResizeMethod
.
BILINEAR
):
"""Computes the segmentation targets.
"""Computes the segmentation targets.
This utility produces a semantic segmentation mask for each class, starting
This utility produces a semantic segmentation mask for each class, starting
...
@@ -2016,15 +2016,25 @@ class CenterNetMaskTargetAssigner(object):
...
@@ -2016,15 +2016,25 @@ class CenterNetMaskTargetAssigner(object):
gt_classes_list: A list of float tensors with shape [num_boxes,
gt_classes_list: A list of float tensors with shape [num_boxes,
num_classes] representing the one-hot encoded class labels for each box
num_classes] representing the one-hot encoded class labels for each box
in the gt_boxes_list.
in the gt_boxes_list.
gt_boxes_list: An optional list of float tensors with shape [num_boxes, 4]
with normalized boxes corresponding to each mask. The boxes are used to
spatially allocate mask weights.
gt_mask_weights_list: An optional list of float tensors with shape
[num_boxes] with weights for each mask. If a mask has a zero weight, it
indicates that the box region associated with the mask should not
contribute to the loss. If not provided, will use a per-pixel weight of
1.
mask_resize_method: A `tf.compat.v2.image.ResizeMethod`. The method to use
mask_resize_method: A `tf.compat.v2.image.ResizeMethod`. The method to use
when resizing masks from input resolution to output resolution.
when resizing masks from input resolution to output resolution.
Returns:
Returns:
segmentation_targets: An int32 tensor of size [batch_size, output_height,
segmentation_targets: An int32 tensor of size [batch_size, output_height,
output_width, num_classes] representing the class of each location in
output_width, num_classes] representing the class of each location in
the output space.
the output space.
segmentation_weight: A float32 tensor of size [batch_size, output_height,
output_width] indicating the loss weight to apply at each location.
"""
"""
# TODO(ronnyvotel): Handle groundtruth weights.
_
,
num_classes
=
shape_utils
.
combined_static_and_dynamic_shape
(
_
,
num_classes
=
shape_utils
.
combined_static_and_dynamic_shape
(
gt_classes_list
[
0
])
gt_classes_list
[
0
])
...
@@ -2033,8 +2043,35 @@ class CenterNetMaskTargetAssigner(object):
...
@@ -2033,8 +2043,35 @@ class CenterNetMaskTargetAssigner(object):
output_height
=
tf
.
maximum
(
input_height
//
self
.
_stride
,
1
)
output_height
=
tf
.
maximum
(
input_height
//
self
.
_stride
,
1
)
output_width
=
tf
.
maximum
(
input_width
//
self
.
_stride
,
1
)
output_width
=
tf
.
maximum
(
input_width
//
self
.
_stride
,
1
)
if
gt_boxes_list
is
None
:
gt_boxes_list
=
[
None
]
*
len
(
gt_masks_list
)
if
gt_mask_weights_list
is
None
:
gt_mask_weights_list
=
[
None
]
*
len
(
gt_masks_list
)
segmentation_targets_list
=
[]
segmentation_targets_list
=
[]
for
gt_masks
,
gt_classes
in
zip
(
gt_masks_list
,
gt_classes_list
):
segmentation_weights_list
=
[]
for
gt_boxes
,
gt_masks
,
gt_mask_weights
,
gt_classes
in
zip
(
gt_boxes_list
,
gt_masks_list
,
gt_mask_weights_list
,
gt_classes_list
):
if
gt_boxes
is
not
None
and
gt_mask_weights
is
not
None
:
boxes
=
box_list
.
BoxList
(
gt_boxes
)
# Convert the box coordinates to absolute output image dimension space.
boxes_absolute
=
box_list_ops
.
to_absolute_coordinates
(
boxes
,
output_height
,
output_width
)
# Generate a segmentation weight that applies mask weights in object
# regions.
blackout
=
gt_mask_weights
<=
0
segmentation_weight_for_image
=
(
ta_utils
.
blackout_pixel_weights_by_box_regions
(
output_height
,
output_width
,
boxes_absolute
.
get
(),
blackout
,
weights
=
gt_mask_weights
))
segmentation_weights_list
.
append
(
segmentation_weight_for_image
)
else
:
segmentation_weights_list
.
append
(
tf
.
ones
((
output_height
,
output_width
),
dtype
=
tf
.
float32
))
gt_masks
=
_resize_masks
(
gt_masks
,
output_height
,
output_width
,
gt_masks
=
_resize_masks
(
gt_masks
,
output_height
,
output_width
,
mask_resize_method
)
mask_resize_method
)
gt_masks
=
gt_masks
[:,
:,
:,
tf
.
newaxis
]
gt_masks
=
gt_masks
[:,
:,
:,
tf
.
newaxis
]
...
@@ -2047,7 +2084,8 @@ class CenterNetMaskTargetAssigner(object):
...
@@ -2047,7 +2084,8 @@ class CenterNetMaskTargetAssigner(object):
segmentation_targets_list
.
append
(
segmentations_for_image
)
segmentation_targets_list
.
append
(
segmentations_for_image
)
segmentation_target
=
tf
.
stack
(
segmentation_targets_list
,
axis
=
0
)
segmentation_target
=
tf
.
stack
(
segmentation_targets_list
,
axis
=
0
)
return
segmentation_target
segmentation_weight
=
tf
.
stack
(
segmentation_weights_list
,
axis
=
0
)
return
segmentation_target
,
segmentation_weight
class
CenterNetDensePoseTargetAssigner
(
object
):
class
CenterNetDensePoseTargetAssigner
(
object
):
...
...
research/object_detection/core/target_assigner_test.py
View file @
4bf492a8
...
@@ -2090,13 +2090,31 @@ class CenterNetMaskTargetAssignerTest(test_case.TestCase):
...
@@ -2090,13 +2090,31 @@ class CenterNetMaskTargetAssignerTest(test_case.TestCase):
tf
.
constant
([[
0.
,
1.
,
0.
],
tf
.
constant
([[
0.
,
1.
,
0.
],
[
0.
,
1.
,
0.
]],
dtype
=
tf
.
float32
)
[
0.
,
1.
,
0.
]],
dtype
=
tf
.
float32
)
]
]
gt_boxes_list
=
[
# Example 0.
tf
.
constant
([[
0.0
,
0.0
,
0.5
,
0.5
],
[
0.0
,
0.5
,
0.5
,
1.0
],
[
0.0
,
0.0
,
1.0
,
1.0
]],
dtype
=
tf
.
float32
),
# Example 1.
tf
.
constant
([[
0.0
,
0.0
,
1.0
,
1.0
],
[
0.5
,
0.0
,
1.0
,
0.5
]],
dtype
=
tf
.
float32
)
]
gt_mask_weights_list
=
[
# Example 0.
tf
.
constant
([
0.0
,
1.0
,
1.0
],
dtype
=
tf
.
float32
),
# Example 1.
tf
.
constant
([
1.0
,
1.0
],
dtype
=
tf
.
float32
)
]
cn_assigner
=
targetassigner
.
CenterNetMaskTargetAssigner
(
stride
=
2
)
cn_assigner
=
targetassigner
.
CenterNetMaskTargetAssigner
(
stride
=
2
)
segmentation_target
=
cn_assigner
.
assign_segmentation_targets
(
segmentation_target
,
segmentation_weight
=
(
cn_assigner
.
assign_segmentation_targets
(
gt_masks_list
=
gt_masks_list
,
gt_masks_list
=
gt_masks_list
,
gt_classes_list
=
gt_classes_list
,
gt_classes_list
=
gt_classes_list
,
mask_resize_method
=
targetassigner
.
ResizeMethod
.
NEAREST_NEIGHBOR
)
gt_boxes_list
=
gt_boxes_list
,
return
segmentation_target
gt_mask_weights_list
=
gt_mask_weights_list
,
segmentation_target
=
self
.
execute
(
graph_fn
,
[])
mask_resize_method
=
targetassigner
.
ResizeMethod
.
NEAREST_NEIGHBOR
))
return
segmentation_target
,
segmentation_weight
segmentation_target
,
segmentation_weight
=
self
.
execute
(
graph_fn
,
[])
expected_seg_target
=
np
.
array
([
expected_seg_target
=
np
.
array
([
# Example 0 [[class 0, class 1], [background, class 0]]
# Example 0 [[class 0, class 1], [background, class 0]]
...
@@ -2108,13 +2126,18 @@ class CenterNetMaskTargetAssignerTest(test_case.TestCase):
...
@@ -2108,13 +2126,18 @@ class CenterNetMaskTargetAssignerTest(test_case.TestCase):
],
dtype
=
np
.
float32
)
],
dtype
=
np
.
float32
)
np
.
testing
.
assert_array_almost_equal
(
np
.
testing
.
assert_array_almost_equal
(
expected_seg_target
,
segmentation_target
)
expected_seg_target
,
segmentation_target
)
expected_seg_weight
=
np
.
array
([
[[
0
,
1
],
[
1
,
1
]],
[[
1
,
1
],
[
1
,
1
]]],
dtype
=
np
.
float32
)
np
.
testing
.
assert_array_almost_equal
(
expected_seg_weight
,
segmentation_weight
)
def
test_assign_segmentation_targets_no_objects
(
self
):
def
test_assign_segmentation_targets_no_objects
(
self
):
def
graph_fn
():
def
graph_fn
():
gt_masks_list
=
[
tf
.
zeros
((
0
,
5
,
5
))]
gt_masks_list
=
[
tf
.
zeros
((
0
,
5
,
5
))]
gt_classes_list
=
[
tf
.
zeros
((
0
,
10
))]
gt_classes_list
=
[
tf
.
zeros
((
0
,
10
))]
cn_assigner
=
targetassigner
.
CenterNetMaskTargetAssigner
(
stride
=
1
)
cn_assigner
=
targetassigner
.
CenterNetMaskTargetAssigner
(
stride
=
1
)
segmentation_target
=
cn_assigner
.
assign_segmentation_targets
(
segmentation_target
,
_
=
cn_assigner
.
assign_segmentation_targets
(
gt_masks_list
=
gt_masks_list
,
gt_masks_list
=
gt_masks_list
,
gt_classes_list
=
gt_classes_list
,
gt_classes_list
=
gt_classes_list
,
mask_resize_method
=
targetassigner
.
ResizeMethod
.
NEAREST_NEIGHBOR
)
mask_resize_method
=
targetassigner
.
ResizeMethod
.
NEAREST_NEIGHBOR
)
...
...
research/object_detection/meta_architectures/center_net_meta_arch.py
View file @
4bf492a8
...
@@ -2979,20 +2979,32 @@ class CenterNetMetaArch(model.DetectionModel):
...
@@ -2979,20 +2979,32 @@ class CenterNetMetaArch(model.DetectionModel):
Returns:
Returns:
A float scalar tensor representing the mask loss.
A float scalar tensor representing the mask loss.
"""
"""
gt_boxes_list
=
self
.
groundtruth_lists
(
fields
.
BoxListFields
.
boxes
)
gt_masks_list
=
self
.
groundtruth_lists
(
fields
.
BoxListFields
.
masks
)
gt_masks_list
=
self
.
groundtruth_lists
(
fields
.
BoxListFields
.
masks
)
gt_mask_weights_list
=
None
if
self
.
groundtruth_has_field
(
fields
.
BoxListFields
.
mask_weights
):
gt_mask_weights_list
=
self
.
groundtruth_lists
(
fields
.
BoxListFields
.
mask_weights
)
gt_classes_list
=
self
.
groundtruth_lists
(
fields
.
BoxListFields
.
classes
)
gt_classes_list
=
self
.
groundtruth_lists
(
fields
.
BoxListFields
.
classes
)
# Convert the groundtruth to targets.
# Convert the groundtruth to targets.
assigner
=
self
.
_target_assigner_dict
[
SEGMENTATION_TASK
]
assigner
=
self
.
_target_assigner_dict
[
SEGMENTATION_TASK
]
heatmap_targets
=
assigner
.
assign_segmentation_targets
(
heatmap_targets
,
heatmap_weight
=
assigner
.
assign_segmentation_targets
(
gt_masks_list
=
gt_masks_list
,
gt_masks_list
=
gt_masks_list
,
gt_classes_list
=
gt_classes_list
)
gt_classes_list
=
gt_classes_list
,
gt_boxes_list
=
gt_boxes_list
,
gt_mask_weights_list
=
gt_mask_weights_list
)
flattened_heatmap_targets
=
_flatten_spatial_dimensions
(
heatmap_targets
)
flattened_heatmap_targets
=
_flatten_spatial_dimensions
(
heatmap_targets
)
flattened_heatmap_mask
=
_flatten_spatial_dimensions
(
heatmap_weight
[:,
:,
:,
tf
.
newaxis
])
per_pixel_weights
*=
flattened_heatmap_mask
loss
=
0.0
loss
=
0.0
mask_loss_fn
=
self
.
_mask_params
.
classification_loss
mask_loss_fn
=
self
.
_mask_params
.
classification_loss
total_pixels_in_loss
=
tf
.
reduce_sum
(
per_pixel_weights
)
total_pixels_in_loss
=
tf
.
math
.
maximum
(
tf
.
reduce_sum
(
per_pixel_weights
),
1
)
# Loop through each feature output head.
# Loop through each feature output head.
for
pred
in
segmentation_predictions
:
for
pred
in
segmentation_predictions
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment