Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
1cbb7a7d
Commit
1cbb7a7d
authored
Jul 19, 2021
by
Vighnesh Birodkar
Committed by
TF Object Detection Team
Jul 19, 2021
Browse files
Add box consistency loss in DeepMAC.
PiperOrigin-RevId: 385646058
parent
76640072
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
304 additions
and
71 deletions
+304
-71
research/object_detection/meta_architectures/deepmac_meta_arch.py
.../object_detection/meta_architectures/deepmac_meta_arch.py
+203
-59
research/object_detection/meta_architectures/deepmac_meta_arch_test.py
...ct_detection/meta_architectures/deepmac_meta_arch_test.py
+97
-12
research/object_detection/protos/center_net.proto
research/object_detection/protos/center_net.proto
+4
-0
No files found.
research/object_detection/meta_architectures/deepmac_meta_arch.py
View file @
1cbb7a7d
...
@@ -26,6 +26,7 @@ from object_detection.utils import spatial_transform_ops
...
@@ -26,6 +26,7 @@ from object_detection.utils import spatial_transform_ops
INSTANCE_EMBEDDING
=
'INSTANCE_EMBEDDING'
INSTANCE_EMBEDDING
=
'INSTANCE_EMBEDDING'
PIXEL_EMBEDDING
=
'PIXEL_EMBEDDING'
PIXEL_EMBEDDING
=
'PIXEL_EMBEDDING'
DEEP_MASK_ESTIMATION
=
'deep_mask_estimation'
DEEP_MASK_ESTIMATION
=
'deep_mask_estimation'
DEEP_MASK_BOX_CONSISTENCY
=
'deep_mask_box_consistency'
LOSS_KEY_PREFIX
=
center_net_meta_arch
.
LOSS_KEY_PREFIX
LOSS_KEY_PREFIX
=
center_net_meta_arch
.
LOSS_KEY_PREFIX
...
@@ -35,7 +36,7 @@ class DeepMACParams(
...
@@ -35,7 +36,7 @@ class DeepMACParams(
'allowed_masked_classes_ids'
,
'mask_size'
,
'mask_num_subsamples'
,
'allowed_masked_classes_ids'
,
'mask_size'
,
'mask_num_subsamples'
,
'use_xy'
,
'network_type'
,
'use_instance_embedding'
,
'num_init_channels'
,
'use_xy'
,
'network_type'
,
'use_instance_embedding'
,
'num_init_channels'
,
'predict_full_resolution_masks'
,
'postprocess_crop_size'
,
'predict_full_resolution_masks'
,
'postprocess_crop_size'
,
'max_roi_jitter_ratio'
,
'roi_jitter_mode'
'max_roi_jitter_ratio'
,
'roi_jitter_mode'
,
'box_consistency_loss_weight'
])):
])):
"""Class holding the DeepMAC network configutration."""
"""Class holding the DeepMAC network configutration."""
...
@@ -46,7 +47,7 @@ class DeepMACParams(
...
@@ -46,7 +47,7 @@ class DeepMACParams(
mask_num_subsamples
,
use_xy
,
network_type
,
use_instance_embedding
,
mask_num_subsamples
,
use_xy
,
network_type
,
use_instance_embedding
,
num_init_channels
,
predict_full_resolution_masks
,
num_init_channels
,
predict_full_resolution_masks
,
postprocess_crop_size
,
max_roi_jitter_ratio
,
postprocess_crop_size
,
max_roi_jitter_ratio
,
roi_jitter_mode
):
roi_jitter_mode
,
box_consistency_loss_weight
):
return
super
(
DeepMACParams
,
return
super
(
DeepMACParams
,
cls
).
__new__
(
cls
,
classification_loss
,
dim
,
cls
).
__new__
(
cls
,
classification_loss
,
dim
,
task_loss_weight
,
pixel_embedding_dim
,
task_loss_weight
,
pixel_embedding_dim
,
...
@@ -55,7 +56,7 @@ class DeepMACParams(
...
@@ -55,7 +56,7 @@ class DeepMACParams(
use_instance_embedding
,
num_init_channels
,
use_instance_embedding
,
num_init_channels
,
predict_full_resolution_masks
,
predict_full_resolution_masks
,
postprocess_crop_size
,
max_roi_jitter_ratio
,
postprocess_crop_size
,
max_roi_jitter_ratio
,
roi_jitter_mode
)
roi_jitter_mode
,
box_consistency_loss_weight
)
def
subsample_instances
(
classes
,
weights
,
boxes
,
masks
,
num_subsamples
):
def
subsample_instances
(
classes
,
weights
,
boxes
,
masks
,
num_subsamples
):
...
@@ -206,6 +207,61 @@ def filter_masked_classes(masked_class_ids, classes, weights, masks):
...
@@ -206,6 +207,61 @@ def filter_masked_classes(masked_class_ids, classes, weights, masks):
)
)
def crop_and_resize_feature_map(features, boxes, size):
  """Extracts one fixed-size square crop of a feature map per box.

  Args:
    features: A [H, W, C] float tensor.
    boxes: A [N, 4] tensor of normalized boxes.
    size: int, the height and width of every output crop.

  Returns:
    per_box_features: A [N, size, size, C] tensor of cropped and resized
      features.
  """
  # A leading dimension of one is added to both inputs for the crop op and
  # removed again from its result.
  batched_features = features[tf.newaxis]
  batched_boxes = boxes[tf.newaxis]
  batched_crops = spatial_transform_ops.matmul_crop_and_resize(
      batched_features, batched_boxes, [size, size])
  return batched_crops[0]
def crop_and_resize_instance_masks(masks, boxes, mask_size):
  """Crops every mask with its own box and resizes the crop to a square.

  Args:
    masks: A [N, H, W] float tensor.
    boxes: A [N, 4] float tensor of normalized boxes.
    mask_size: int, the height and width of the output masks.

  Returns:
    masks: A [N, mask_size, mask_size] float tensor of cropped and resized
      instance masks.
  """
  # Each mask gets a trailing dimension and each box its own dimension of
  # length one, so the i-th box is applied to the i-th mask only.
  masks_expanded = masks[:, :, :, tf.newaxis]
  boxes_expanded = boxes[:, tf.newaxis, :]
  crops = spatial_transform_ops.matmul_crop_and_resize(
      masks_expanded, boxes_expanded, [mask_size, mask_size])
  # Drop the two length-one dimensions introduced above.
  return tf.squeeze(crops, axis=[1, 4])
def fill_boxes(boxes, height, width):
  """Rasterizes boxes into masks that are 1.0 inside the box and 0.0 outside.

  Args:
    boxes: A [N, 4] float tensor of normalized boxes in
      [ymin, xmin, ymax, xmax] order.
    height: The height of the output masks.
    width: The width of the output masks.

  Returns:
    A [N, height, width] float32 tensor. A pixel is set when its row index
    lies in [ymin, ymax] and its column index lies in [xmin, xmax], both
    bounds inclusive, with the box expressed in absolute coordinates.
  """
  absolute_boxes = box_list_ops.to_absolute_coordinates(
      box_list.BoxList(boxes), height, width).get()

  # Each coordinate has shape [N, 1, 1] so it broadcasts against the
  # [1, height, width] pixel grids below.
  expanded_boxes = absolute_boxes[:, tf.newaxis, tf.newaxis, :]
  ymin, xmin, ymax, xmax = tf.unstack(expanded_boxes, 4, axis=3)

  rows, cols = tf.meshgrid(tf.range(height), tf.range(width), indexing='ij')
  rows = tf.cast(rows, tf.float32)
  cols = tf.cast(cols, tf.float32)
  rows = rows[tf.newaxis, :, :]
  cols = cols[tf.newaxis, :, :]

  inside_rows = tf.logical_and(rows >= ymin, rows <= ymax)
  inside_cols = tf.logical_and(cols >= xmin, cols <= xmax)
  inside_box = tf.logical_and(inside_rows, inside_cols)

  return tf.cast(inside_box, tf.float32)
class
ResNetMaskNetwork
(
tf
.
keras
.
layers
.
Layer
):
class
ResNetMaskNetwork
(
tf
.
keras
.
layers
.
Layer
):
"""A small wrapper around ResNet blocks to predict masks."""
"""A small wrapper around ResNet blocks to predict masks."""
...
@@ -379,7 +435,8 @@ def deepmac_proto_to_params(deepmac_config):
...
@@ -379,7 +435,8 @@ def deepmac_proto_to_params(deepmac_config):
deepmac_config
.
predict_full_resolution_masks
,
deepmac_config
.
predict_full_resolution_masks
,
postprocess_crop_size
=
deepmac_config
.
postprocess_crop_size
,
postprocess_crop_size
=
deepmac_config
.
postprocess_crop_size
,
max_roi_jitter_ratio
=
deepmac_config
.
max_roi_jitter_ratio
,
max_roi_jitter_ratio
=
deepmac_config
.
max_roi_jitter_ratio
,
roi_jitter_mode
=
jitter_mode
roi_jitter_mode
=
jitter_mode
,
box_consistency_loss_weight
=
deepmac_config
.
box_consistency_loss_weight
)
)
...
@@ -402,6 +459,13 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
...
@@ -402,6 +459,13 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
"""Constructs the super class with object center & detection params only."""
"""Constructs the super class with object center & detection params only."""
self
.
_deepmac_params
=
deepmac_params
self
.
_deepmac_params
=
deepmac_params
if
(
self
.
_deepmac_params
.
predict_full_resolution_masks
and
self
.
_deepmac_params
.
max_roi_jitter_ratio
>
0.0
):
raise
ValueError
(
'Jittering is not supported for full res masks.'
)
if
self
.
_deepmac_params
.
mask_num_subsamples
>
0
:
raise
ValueError
(
'Subsampling masks is currently not supported.'
)
super
(
DeepMACMetaArch
,
self
).
__init__
(
super
(
DeepMACMetaArch
,
self
).
__init__
(
is_training
=
is_training
,
add_summaries
=
add_summaries
,
is_training
=
is_training
,
add_summaries
=
add_summaries
,
num_classes
=
num_classes
,
feature_extractor
=
feature_extractor
,
num_classes
=
num_classes
,
feature_extractor
=
feature_extractor
,
...
@@ -462,21 +526,34 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
...
@@ -462,21 +526,34 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
pixel_embedding
=
pixel_embedding
[
tf
.
newaxis
,
:,
:,
:]
pixel_embedding
=
pixel_embedding
[
tf
.
newaxis
,
:,
:,
:]
pixel_embeddings_processed
=
tf
.
tile
(
pixel_embedding
,
pixel_embeddings_processed
=
tf
.
tile
(
pixel_embedding
,
[
num_instances
,
1
,
1
,
1
])
[
num_instances
,
1
,
1
,
1
])
image_shape
=
tf
.
shape
(
pixel_embeddings_processed
)
image_height
,
image_width
=
image_shape
[
1
],
image_shape
[
2
]
y_grid
,
x_grid
=
tf
.
meshgrid
(
tf
.
linspace
(
0.0
,
1.0
,
image_height
),
tf
.
linspace
(
0.0
,
1.0
,
image_width
),
indexing
=
'ij'
)
blist
=
box_list
.
BoxList
(
boxes
)
ycenter
,
xcenter
,
_
,
_
=
blist
.
get_center_coordinates_and_sizes
()
y_grid
=
y_grid
[
tf
.
newaxis
,
:,
:]
x_grid
=
x_grid
[
tf
.
newaxis
,
:,
:]
y_grid
-=
ycenter
[:,
tf
.
newaxis
,
tf
.
newaxis
]
x_grid
-=
xcenter
[:,
tf
.
newaxis
,
tf
.
newaxis
]
coords
=
tf
.
stack
([
y_grid
,
x_grid
],
axis
=
3
)
else
:
else
:
# TODO(vighneshb) Explore multilevel_roi_align and align_corners=False.
# TODO(vighneshb) Explore multilevel_roi_align and align_corners=False.
pixel_embeddings_cropped
=
spatial_transform_ops
.
matmul_crop_and_resize
(
pixel_embeddings_processed
=
crop_and_resize_feature_map
(
pixel_embedding
[
tf
.
newaxis
],
boxes
[
tf
.
newaxis
],
pixel_embedding
,
boxes
,
mask_size
)
[
mask_size
,
mask_size
])
mask_shape
=
tf
.
shape
(
pixel_embeddings_processed
)
pixel_embeddings_processed
=
pixel_embeddings_cropped
[
0
]
mask_height
,
mask_width
=
mask_shape
[
1
],
mask_shape
[
2
]
y_grid
,
x_grid
=
tf
.
meshgrid
(
tf
.
linspace
(
-
1.0
,
1.0
,
mask_height
),
mask_shape
=
tf
.
shape
(
pixel_embeddings_processed
)
tf
.
linspace
(
-
1.0
,
1.0
,
mask_width
),
mask_height
,
mask_width
=
mask_shape
[
1
],
mask_shape
[
2
]
indexing
=
'ij'
)
y_grid
,
x_grid
=
tf
.
meshgrid
(
tf
.
linspace
(
-
1.0
,
1.0
,
mask_height
),
tf
.
linspace
(
-
1.0
,
1.0
,
mask_width
),
coords
=
tf
.
stack
([
y_grid
,
x_grid
],
axis
=
2
)
indexing
=
'ij'
)
coords
=
coords
[
tf
.
newaxis
,
:,
:,
:]
coords
=
tf
.
stack
([
y_grid
,
x_grid
],
axis
=
2
)
coords
=
tf
.
tile
(
coords
,
[
num_instances
,
1
,
1
,
1
])
coords
=
coords
[
tf
.
newaxis
,
:,
:,
:]
coords
=
tf
.
tile
(
coords
,
[
num_instances
,
1
,
1
,
1
])
if
self
.
_deepmac_params
.
use_xy
:
if
self
.
_deepmac_params
.
use_xy
:
return
tf
.
concat
([
coords
,
pixel_embeddings_processed
],
axis
=
3
)
return
tf
.
concat
([
coords
,
pixel_embeddings_processed
],
axis
=
3
)
...
@@ -528,11 +605,9 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
...
@@ -528,11 +605,9 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
if
self
.
_deepmac_params
.
predict_full_resolution_masks
:
if
self
.
_deepmac_params
.
predict_full_resolution_masks
:
return
masks
return
masks
else
:
else
:
cropped_masks
=
spatial_transform_ops
.
matmul_crop_and_resize
(
cropped_masks
=
crop_and_resize_instance_masks
(
masks
[:,
:,
:,
tf
.
newaxis
],
boxes
[:,
tf
.
newaxis
,
:],
masks
,
boxes
,
mask_size
)
[
mask_size
,
mask_size
])
cropped_masks
=
tf
.
stop_gradient
(
cropped_masks
)
cropped_masks
=
tf
.
stop_gradient
(
cropped_masks
)
cropped_masks
=
tf
.
squeeze
(
cropped_masks
,
axis
=
[
1
,
4
])
# TODO(vighneshb) should we discretize masks?
# TODO(vighneshb) should we discretize masks?
return
cropped_masks
return
cropped_masks
...
@@ -543,7 +618,64 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
...
@@ -543,7 +618,64 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
return
resize_instance_masks
(
logits
,
(
height
,
width
))
return
resize_instance_masks
(
logits
,
(
height
,
width
))
def
_compute_per_instance_mask_loss
(
def _compute_per_instance_mask_prediction_loss(
    self, boxes, mask_logits, mask_gt):
  """Computes the classification loss between predicted and target masks.

  Args:
    boxes: A tensor whose first dimension is the number of instances. Only
      that dimension is read here.
    mask_logits: Predicted mask logits, one mask per instance. They are
      resized with `_resize_logits_like_gt` before the loss is evaluated.
    mask_gt: Target masks, one per instance.

  Returns:
    A tensor holding one loss value per instance.
  """
  instance_count = tf.shape(boxes)[0]
  flat_shape = [instance_count, -1, 1]

  resized_logits = self._resize_logits_like_gt(mask_logits, mask_gt)
  flat_logits = tf.reshape(resized_logits, flat_shape)
  flat_targets = tf.reshape(mask_gt, flat_shape)

  loss_fn = self._deepmac_params.classification_loss
  raw_loss = loss_fn(
      prediction_tensor=flat_logits,
      target_tensor=flat_targets,
      weights=tf.ones_like(flat_logits))

  # TODO(vighneshb) Make this configurable via config.
  # Skip normalization for dice loss because the denominator term already
  # does normalization.
  if isinstance(loss_fn, losses.WeightedDiceClassificationLoss):
    return tf.reduce_sum(raw_loss, axis=1)
  return tf.reduce_mean(raw_loss, axis=[1, 2])
def _compute_per_instance_box_consistency_loss(
    self, boxes_gt, boxes_for_crop, mask_logits):
  """Computes the box consistency loss for each instance.

  The predicted mask and a mask obtained by filling the groundtruth box are
  each max-reduced along one spatial axis at a time, and the configured
  classification loss is evaluated between the matching reductions.

  Args:
    boxes_gt: A [num_instances, 4] tensor of normalized groundtruth boxes.
    boxes_for_crop: A [num_instances, 4] tensor of normalized boxes used to
      crop both masks. It is not read when full resolution masks are
      predicted.
    mask_logits: A [num_instances, height, width] float tensor of predicted
      mask logits.

  Returns:
    A tensor holding one loss value per instance.
  """
  logits_shape = tf.shape(mask_logits)
  height, width = logits_shape[1], logits_shape[2]
  filled_boxes = fill_boxes(boxes_gt, height, width)[:, :, :, tf.newaxis]
  mask_logits = mask_logits[:, :, :, tf.newaxis]

  if self._deepmac_params.predict_full_resolution_masks:
    gt_crop = filled_boxes[:, :, :, 0]
    pred_crop = mask_logits[:, :, :, 0]
  else:
    # NOTE(review): crop_and_resize_instance_masks documents [N, H, W]
    # inputs but receives [N, H, W, 1] tensors here. This is kept exactly as
    # in the original code; confirm it is intended.
    mask_size = self._deepmac_params.mask_size
    gt_crop = crop_and_resize_instance_masks(
        filled_boxes, boxes_for_crop, mask_size)
    pred_crop = crop_and_resize_instance_masks(
        mask_logits, boxes_for_crop, mask_size)

  loss_fn = self._deepmac_params.classification_loss
  # Skip normalization for dice loss because the denominator term already
  # does normalization.
  # TODO(vighneshb) Make this configurable via config.
  is_dice_loss = isinstance(loss_fn, losses.WeightedDiceClassificationLoss)

  loss = 0.0
  for axis in [1, 2]:
    pred_max = tf.reduce_max(pred_crop, axis=axis)[:, :, tf.newaxis]
    gt_max = tf.reduce_max(gt_crop, axis=axis)[:, :, tf.newaxis]

    axis_loss = loss_fn(
        prediction_tensor=pred_max,
        target_tensor=gt_max,
        weights=tf.ones_like(pred_max))

    # Reduce each axis to one value per instance before accumulating. The
    # two reductions have lengths `width` and `height`, so adding the
    # unreduced losses elementwise only works for square masks. For square
    # masks the result equals reducing the elementwise sum.
    if is_dice_loss:
      loss += tf.reduce_sum(axis_loss, axis=1)
    else:
      loss += tf.reduce_mean(axis_loss, axis=[1, 2])

  return loss
def
_compute_per_instance_deepmac_losses
(
self
,
boxes
,
masks
,
instance_embedding
,
pixel_embedding
):
self
,
boxes
,
masks
,
instance_embedding
,
pixel_embedding
):
"""Returns the mask loss per instance.
"""Returns the mask loss per instance.
...
@@ -558,40 +690,36 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
...
@@ -558,40 +690,36 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
pixel_embedding_size] float tensor containing the per-pixel embeddings.
pixel_embedding_size] float tensor containing the per-pixel embeddings.
Returns:
Returns:
mask_loss: A [num_instances] shaped float tensor containing the
mask_
prediction_
loss: A [num_instances] shaped float tensor containing the
mask loss for each instance.
mask loss for each instance.
"""
box_consistency_loss: A [num_instances] shaped float tensor containing
the box consistency loss for each instance.
num_instances
=
tf
.
shape
(
boxes
)[
0
]
"""
if
tf
.
keras
.
backend
.
learning_phase
():
if
tf
.
keras
.
backend
.
learning_phase
():
boxes
=
preprocessor
.
random_jitter_boxes
(
boxes
_for_crop
=
preprocessor
.
random_jitter_boxes
(
boxes
,
self
.
_deepmac_params
.
max_roi_jitter_ratio
,
boxes
,
self
.
_deepmac_params
.
max_roi_jitter_ratio
,
jitter_mode
=
self
.
_deepmac_params
.
roi_jitter_mode
)
jitter_mode
=
self
.
_deepmac_params
.
roi_jitter_mode
)
else
:
boxes_for_crop
=
boxes
mask_input
=
self
.
_get_mask_head_input
(
mask_input
=
self
.
_get_mask_head_input
(
boxes
,
pixel_embedding
)
boxes
_for_crop
,
pixel_embedding
)
instance_embeddings
=
self
.
_get_instance_embeddings
(
instance_embeddings
=
self
.
_get_instance_embeddings
(
boxes
,
instance_embedding
)
boxes_for_crop
,
instance_embedding
)
mask_logits
=
self
.
_mask_net
(
mask_logits
=
self
.
_mask_net
(
instance_embeddings
,
mask_input
,
instance_embeddings
,
mask_input
,
training
=
tf
.
keras
.
backend
.
learning_phase
())
training
=
tf
.
keras
.
backend
.
learning_phase
())
mask_gt
=
self
.
_get_groundtruth_mask_output
(
boxes
,
masks
)
mask_gt
=
self
.
_get_groundtruth_mask_output
(
boxes_for_crop
,
masks
)
mask_logits
=
self
.
_resize_logits_like_gt
(
mask_logits
,
mask_gt
)
mask_logits
=
tf
.
reshape
(
mask_logits
,
[
num_instances
,
-
1
,
1
])
mask_prediction_loss
=
self
.
_compute_per_instance_mask_prediction_loss
(
mask_gt
=
tf
.
reshape
(
mask_gt
,
[
num_instances
,
-
1
,
1
])
boxes_for_crop
,
mask_logits
,
mask_gt
)
loss
=
self
.
_deepmac_params
.
classification_loss
(
prediction_tensor
=
mask_logits
,
target_tensor
=
mask_gt
,
weights
=
tf
.
ones_like
(
mask_logits
))
# TODO(vighneshb) Make this configurable via config.
box_consistency_loss
=
self
.
_compute_per_instance_box_consistency_loss
(
if
isinstance
(
self
.
_deepmac_params
.
classification_loss
,
boxes
,
boxes_for_crop
,
mask_logits
)
losses
.
WeightedDiceClassificationLoss
):
return
tf
.
reduce_sum
(
loss
,
axis
=
1
)
return
mask_prediction_loss
,
box_consistency_loss
else
:
return
tf
.
reduce_mean
(
loss
,
axis
=
[
1
,
2
])
def
_compute_instance_masks_loss
(
self
,
prediction_dict
):
def
_compute_instance_masks_loss
(
self
,
prediction_dict
):
"""Computes the mask loss.
"""Computes the mask loss.
...
@@ -603,7 +731,7 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
...
@@ -603,7 +731,7 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
[batch_size, height, width, embedding_size].
[batch_size, height, width, embedding_size].
Returns:
Returns:
loss
: float, the mask loss as a scalar
.
loss
_dict: A dict mapping string (loss names) to scalar floats
.
"""
"""
gt_boxes_list
=
self
.
groundtruth_lists
(
fields
.
BoxListFields
.
boxes
)
gt_boxes_list
=
self
.
groundtruth_lists
(
fields
.
BoxListFields
.
boxes
)
gt_weights_list
=
self
.
groundtruth_lists
(
fields
.
BoxListFields
.
weights
)
gt_weights_list
=
self
.
groundtruth_lists
(
fields
.
BoxListFields
.
weights
)
...
@@ -613,7 +741,10 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
...
@@ -613,7 +741,10 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
allowed_masked_classes_ids
=
(
allowed_masked_classes_ids
=
(
self
.
_deepmac_params
.
allowed_masked_classes_ids
)
self
.
_deepmac_params
.
allowed_masked_classes_ids
)
total_loss
=
0.0
loss_dict
=
{
DEEP_MASK_ESTIMATION
:
0.0
,
DEEP_MASK_BOX_CONSISTENCY
:
0.0
}
# Iterate over multiple predictions by backbone (for hourglass length=2)
# Iterate over multiple predictions by backbone (for hourglass length=2)
for
instance_pred
,
pixel_pred
in
zip
(
for
instance_pred
,
pixel_pred
in
zip
(
...
@@ -625,24 +756,31 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
...
@@ -625,24 +756,31 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
for
i
,
(
boxes
,
weights
,
classes
,
masks
)
in
enumerate
(
for
i
,
(
boxes
,
weights
,
classes
,
masks
)
in
enumerate
(
zip
(
gt_boxes_list
,
gt_weights_list
,
gt_classes_list
,
gt_masks_list
)):
zip
(
gt_boxes_list
,
gt_weights_list
,
gt_classes_list
,
gt_masks_list
)):
_
,
weights
,
masks
=
filter_masked_classes
(
allowed_masked_classes_ids
,
# TODO(vighneshb) Add sub-sampling back if required.
classes
,
weights
,
masks
)
classes
,
valid_mask_weights
,
masks
=
filter_masked_classes
(
num_subsample
=
self
.
_deepmac_params
.
mask_num_subsamples
allowed_masked_classes_ids
,
classes
,
weights
,
masks
)
_
,
weights
,
boxes
,
masks
=
subsample_instances
(
classes
,
weights
,
boxes
,
masks
,
num_subsample
)
per_instance_loss
=
self
.
_compute_per_instance_mask_loss
(
per_instance_mask_loss
,
per_instance_consistency_loss
=
(
boxes
,
masks
,
instance_pred
[
i
],
pixel_pred
[
i
])
self
.
_compute_per_instance_deepmac_losses
(
per_instance_loss
*=
weights
boxes
,
masks
,
instance_pred
[
i
],
pixel_pred
[
i
]))
per_instance_mask_loss
*=
valid_mask_weights
per_instance_consistency_loss
*=
weights
num_instances
=
tf
.
maximum
(
tf
.
reduce_sum
(
weights
),
1.0
)
num_instances
=
tf
.
maximum
(
tf
.
reduce_sum
(
weights
),
1.0
)
num_instances_allowed
=
tf
.
maximum
(
tf
.
reduce_sum
(
valid_mask_weights
),
1.0
)
total_loss
+=
tf
.
reduce_sum
(
per_instance_loss
)
/
num_instances
loss_dict
[
DEEP_MASK_ESTIMATION
]
+=
(
tf
.
reduce_sum
(
per_instance_mask_loss
)
/
num_instances_allowed
)
loss_dict
[
DEEP_MASK_BOX_CONSISTENCY
]
+=
(
tf
.
reduce_sum
(
per_instance_consistency_loss
)
/
num_instances
)
batch_size
=
len
(
gt_boxes_list
)
batch_size
=
len
(
gt_boxes_list
)
num_predictions
=
len
(
prediction_dict
[
INSTANCE_EMBEDDING
])
num_predictions
=
len
(
prediction_dict
[
INSTANCE_EMBEDDING
])
return
total_loss
/
float
(
batch_size
*
num_predictions
)
return
dict
((
key
,
loss
/
float
(
batch_size
*
num_predictions
))
for
key
,
loss
in
loss_dict
.
items
())
def
loss
(
self
,
prediction_dict
,
true_image_shapes
,
scope
=
None
):
def
loss
(
self
,
prediction_dict
,
true_image_shapes
,
scope
=
None
):
...
@@ -650,13 +788,19 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
...
@@ -650,13 +788,19 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
prediction_dict
,
true_image_shapes
,
scope
)
prediction_dict
,
true_image_shapes
,
scope
)
if
self
.
_deepmac_params
is
not
None
:
if
self
.
_deepmac_params
is
not
None
:
mask_loss
=
self
.
_compute_instance_masks_loss
(
mask_loss
_dict
=
self
.
_compute_instance_masks_loss
(
prediction_dict
=
prediction_dict
)
prediction_dict
=
prediction_dict
)
key
=
LOSS_KEY_PREFIX
+
'/'
+
DEEP_MASK_ESTIMATION
losses_dict
[
key
]
=
(
losses_dict
[
LOSS_KEY_PREFIX
+
'/'
+
DEEP_MASK_ESTIMATION
]
=
(
self
.
_deepmac_params
.
task_loss_weight
*
mask_loss
self
.
_deepmac_params
.
task_loss_weight
*
mask_loss_dict
[
DEEP_MASK_ESTIMATION
]
)
)
if
self
.
_deepmac_params
.
box_consistency_loss_weight
>
0.0
:
losses_dict
[
LOSS_KEY_PREFIX
+
'/'
+
DEEP_MASK_BOX_CONSISTENCY
]
=
(
self
.
_deepmac_params
.
box_consistency_loss_weight
*
mask_loss_dict
[
DEEP_MASK_BOX_CONSISTENCY
]
)
return
losses_dict
return
losses_dict
def
postprocess
(
self
,
prediction_dict
,
true_image_shapes
,
**
params
):
def
postprocess
(
self
,
prediction_dict
,
true_image_shapes
,
**
params
):
...
...
research/object_detection/meta_architectures/deepmac_meta_arch_test.py
View file @
1cbb7a7d
...
@@ -60,7 +60,8 @@ class MockMaskNet(tf.keras.layers.Layer):
...
@@ -60,7 +60,8 @@ class MockMaskNet(tf.keras.layers.Layer):
return
tf
.
zeros_like
(
pixel_embedding
[:,
:,
:,
0
])
+
0.9
return
tf
.
zeros_like
(
pixel_embedding
[:,
:,
:,
0
])
+
0.9
def
build_meta_arch
(
predict_full_resolution_masks
=
False
,
use_dice_loss
=
False
):
def
build_meta_arch
(
predict_full_resolution_masks
=
False
,
use_dice_loss
=
False
,
mask_num_subsamples
=-
1
):
"""Builds the DeepMAC meta architecture."""
"""Builds the DeepMAC meta architecture."""
feature_extractor
=
DummyFeatureExtractor
(
feature_extractor
=
DummyFeatureExtractor
(
...
@@ -94,7 +95,7 @@ def build_meta_arch(predict_full_resolution_masks=False, use_dice_loss=False):
...
@@ -94,7 +95,7 @@ def build_meta_arch(predict_full_resolution_masks=False, use_dice_loss=False):
pixel_embedding_dim
=
2
,
pixel_embedding_dim
=
2
,
allowed_masked_classes_ids
=
[],
allowed_masked_classes_ids
=
[],
mask_size
=
16
,
mask_size
=
16
,
mask_num_subsamples
=
-
1
,
mask_num_subsamples
=
mask_num_subsamples
,
use_xy
=
True
,
use_xy
=
True
,
network_type
=
'hourglass10'
,
network_type
=
'hourglass10'
,
use_instance_embedding
=
True
,
use_instance_embedding
=
True
,
...
@@ -102,7 +103,8 @@ def build_meta_arch(predict_full_resolution_masks=False, use_dice_loss=False):
...
@@ -102,7 +103,8 @@ def build_meta_arch(predict_full_resolution_masks=False, use_dice_loss=False):
predict_full_resolution_masks
=
predict_full_resolution_masks
,
predict_full_resolution_masks
=
predict_full_resolution_masks
,
postprocess_crop_size
=
128
,
postprocess_crop_size
=
128
,
max_roi_jitter_ratio
=
0.0
,
max_roi_jitter_ratio
=
0.0
,
roi_jitter_mode
=
'random'
roi_jitter_mode
=
'random'
,
box_consistency_loss_weight
=
1.0
,
)
)
object_detection_params
=
center_net_meta_arch
.
ObjectDetectionParams
(
object_detection_params
=
center_net_meta_arch
.
ObjectDetectionParams
(
...
@@ -140,6 +142,33 @@ class DeepMACUtilsTest(tf.test.TestCase):
...
@@ -140,6 +142,33 @@ class DeepMACUtilsTest(tf.test.TestCase):
self
.
assertAllClose
(
result
[
2
],
boxes
)
self
.
assertAllClose
(
result
[
2
],
boxes
)
self
.
assertAllClose
(
result
[
3
],
masks
)
self
.
assertAllClose
(
result
[
3
],
masks
)
def test_fill_boxes(self):
  """Checks that fill_boxes sets exactly the pixels covered by each box."""
  side = 32
  normalized_boxes = tf.constant([[0., 0., 0.5, 0.5], [0.5, 0.5, 1.0, 1.0]])

  want = np.zeros((2, side, side))
  # Box bounds are inclusive, so the first box covers indices 0 through 16.
  want[0, :17, :17] = 1.0
  want[1, 16:, 16:] = 1.0

  got = deepmac_meta_arch.fill_boxes(normalized_boxes, side, side)
  self.assertAllClose(want, got.numpy(), rtol=1e-3)
def test_crop_and_resize_instance_masks(self):
  """Checks the output shape of crop_and_resize_instance_masks."""
  instance_masks = tf.zeros((5, 128, 128))
  instance_boxes = tf.zeros((5, 4))
  cropped = deepmac_meta_arch.crop_and_resize_instance_masks(
      instance_masks, instance_boxes, 32)
  self.assertEqual(cropped.shape, (5, 32, 32))
def test_crop_and_resize_feature_map(self):
  """Checks the output shape of crop_and_resize_feature_map."""
  feature_map = tf.zeros((128, 128, 7))
  crop_boxes = tf.zeros((5, 4))
  cropped = deepmac_meta_arch.crop_and_resize_feature_map(
      feature_map, crop_boxes, 32)
  self.assertEqual(cropped.shape, (5, 32, 32, 7))
@
unittest
.
skipIf
(
tf_version
.
is_tf1
(),
'Skipping TF2.X only test.'
)
@
unittest
.
skipIf
(
tf_version
.
is_tf1
(),
'Skipping TF2.X only test.'
)
class
DeepMACMetaArchTest
(
tf
.
test
.
TestCase
):
class
DeepMACMetaArchTest
(
tf
.
test
.
TestCase
):
...
@@ -199,7 +228,7 @@ class DeepMACMetaArchTest(tf.test.TestCase):
...
@@ -199,7 +228,7 @@ class DeepMACMetaArchTest(tf.test.TestCase):
def
test_get_mask_head_input_no_crop_resize
(
self
):
def
test_get_mask_head_input_no_crop_resize
(
self
):
model
=
build_meta_arch
(
predict_full_resolution_masks
=
True
)
model
=
build_meta_arch
(
predict_full_resolution_masks
=
True
)
boxes
=
tf
.
constant
([[
0.
,
0.
,
0
.0
,
0
.0
],
[
0.0
,
0.0
,
0.
0
,
0
.0
]],
boxes
=
tf
.
constant
([[
0.
,
0.
,
1
.0
,
1
.0
],
[
0.0
,
0.0
,
0.
5
,
1
.0
]],
dtype
=
tf
.
float32
)
dtype
=
tf
.
float32
)
pixel_embedding_np
=
np
.
random
.
randn
(
32
,
32
,
4
).
astype
(
np
.
float32
)
pixel_embedding_np
=
np
.
random
.
randn
(
32
,
32
,
4
).
astype
(
np
.
float32
)
...
@@ -208,12 +237,15 @@ class DeepMACMetaArchTest(tf.test.TestCase):
...
@@ -208,12 +237,15 @@ class DeepMACMetaArchTest(tf.test.TestCase):
mask_inputs
=
model
.
_get_mask_head_input
(
boxes
,
pixel_embedding
)
mask_inputs
=
model
.
_get_mask_head_input
(
boxes
,
pixel_embedding
)
self
.
assertEqual
(
mask_inputs
.
shape
,
(
2
,
32
,
32
,
6
))
self
.
assertEqual
(
mask_inputs
.
shape
,
(
2
,
32
,
32
,
6
))
y_grid
,
x_grid
=
tf
.
meshgrid
(
np
.
linspace
(
-
1.0
,
1.0
,
32
),
y_grid
,
x_grid
=
tf
.
meshgrid
(
np
.
linspace
(.
0
,
1.0
,
32
),
np
.
linspace
(
-
1.0
,
1.0
,
32
),
indexing
=
'ij'
)
np
.
linspace
(.
0
,
1.0
,
32
),
indexing
=
'ij'
)
ys
=
[
0.5
,
0.25
]
xs
=
[
0.5
,
0.5
]
for
i
in
range
(
2
):
for
i
in
range
(
2
):
mask_input
=
mask_inputs
[
i
]
mask_input
=
mask_inputs
[
i
]
self
.
assertAllClose
(
y_grid
,
mask_input
[:,
:,
0
])
self
.
assertAllClose
(
y_grid
-
ys
[
i
]
,
mask_input
[:,
:,
0
])
self
.
assertAllClose
(
x_grid
,
mask_input
[:,
:,
1
])
self
.
assertAllClose
(
x_grid
-
xs
[
i
]
,
mask_input
[:,
:,
1
])
pixel_embedding
=
mask_input
[:,
:,
2
:]
pixel_embedding
=
mask_input
[:,
:,
2
:]
self
.
assertAllClose
(
pixel_embedding_np
,
pixel_embedding
)
self
.
assertAllClose
(
pixel_embedding_np
,
pixel_embedding
)
...
@@ -262,7 +294,7 @@ class DeepMACMetaArchTest(tf.test.TestCase):
...
@@ -262,7 +294,7 @@ class DeepMACMetaArchTest(tf.test.TestCase):
masks
[
1
,
16
:,
16
:]
=
1.0
masks
[
1
,
16
:,
16
:]
=
1.0
masks
=
tf
.
constant
(
masks
)
masks
=
tf
.
constant
(
masks
)
loss
=
model
.
_compute_per_instance_
mask
_loss
(
loss
,
_
=
model
.
_compute_per_instance_
deepmac
_loss
es
(
boxes
,
masks
,
tf
.
zeros
((
32
,
32
,
2
)),
tf
.
zeros
((
32
,
32
,
2
)))
boxes
,
masks
,
tf
.
zeros
((
32
,
32
,
2
)),
tf
.
zeros
((
32
,
32
,
2
)))
self
.
assertAllClose
(
self
.
assertAllClose
(
loss
,
np
.
zeros
(
2
)
-
tf
.
math
.
log
(
tf
.
nn
.
sigmoid
(
0.9
)))
loss
,
np
.
zeros
(
2
)
-
tf
.
math
.
log
(
tf
.
nn
.
sigmoid
(
0.9
)))
...
@@ -275,7 +307,7 @@ class DeepMACMetaArchTest(tf.test.TestCase):
...
@@ -275,7 +307,7 @@ class DeepMACMetaArchTest(tf.test.TestCase):
masks
=
np
.
ones
((
2
,
128
,
128
),
dtype
=
np
.
float32
)
masks
=
np
.
ones
((
2
,
128
,
128
),
dtype
=
np
.
float32
)
masks
=
tf
.
constant
(
masks
)
masks
=
tf
.
constant
(
masks
)
loss
=
model
.
_compute_per_instance_
mask
_loss
(
loss
,
_
=
model
.
_compute_per_instance_
deepmac
_loss
es
(
boxes
,
masks
,
tf
.
zeros
((
32
,
32
,
2
)),
tf
.
zeros
((
32
,
32
,
2
)))
boxes
,
masks
,
tf
.
zeros
((
32
,
32
,
2
)),
tf
.
zeros
((
32
,
32
,
2
)))
self
.
assertAllClose
(
self
.
assertAllClose
(
loss
,
np
.
zeros
(
2
)
-
tf
.
math
.
log
(
tf
.
nn
.
sigmoid
(
0.9
)))
loss
,
np
.
zeros
(
2
)
-
tf
.
math
.
log
(
tf
.
nn
.
sigmoid
(
0.9
)))
...
@@ -289,7 +321,7 @@ class DeepMACMetaArchTest(tf.test.TestCase):
...
@@ -289,7 +321,7 @@ class DeepMACMetaArchTest(tf.test.TestCase):
masks
=
np
.
ones
((
2
,
128
,
128
),
dtype
=
np
.
float32
)
masks
=
np
.
ones
((
2
,
128
,
128
),
dtype
=
np
.
float32
)
masks
=
tf
.
constant
(
masks
)
masks
=
tf
.
constant
(
masks
)
loss
=
model
.
_compute_per_instance_
mask
_loss
(
loss
,
_
=
model
.
_compute_per_instance_
deepmac
_loss
es
(
boxes
,
masks
,
tf
.
zeros
((
32
,
32
,
2
)),
tf
.
zeros
((
32
,
32
,
2
)))
boxes
,
masks
,
tf
.
zeros
((
32
,
32
,
2
)),
tf
.
zeros
((
32
,
32
,
2
)))
pred
=
tf
.
nn
.
sigmoid
(
0.9
)
pred
=
tf
.
nn
.
sigmoid
(
0.9
)
expected
=
(
1.0
-
((
2.0
*
pred
)
/
(
1.0
+
pred
)))
expected
=
(
1.0
-
((
2.0
*
pred
)
/
(
1.0
+
pred
)))
...
@@ -299,7 +331,7 @@ class DeepMACMetaArchTest(tf.test.TestCase):
...
@@ -299,7 +331,7 @@ class DeepMACMetaArchTest(tf.test.TestCase):
boxes
=
tf
.
zeros
([
0
,
4
])
boxes
=
tf
.
zeros
([
0
,
4
])
masks
=
tf
.
zeros
([
0
,
128
,
128
])
masks
=
tf
.
zeros
([
0
,
128
,
128
])
loss
=
self
.
model
.
_compute_per_instance_
mask
_loss
(
loss
,
_
=
self
.
model
.
_compute_per_instance_
deepmac
_loss
es
(
boxes
,
masks
,
tf
.
zeros
((
32
,
32
,
2
)),
tf
.
zeros
((
32
,
32
,
2
)))
boxes
,
masks
,
tf
.
zeros
((
32
,
32
,
2
)),
tf
.
zeros
((
32
,
32
,
2
)))
self
.
assertEqual
(
loss
.
shape
,
(
0
,))
self
.
assertEqual
(
loss
.
shape
,
(
0
,))
...
@@ -394,6 +426,59 @@ class DeepMACMetaArchTest(tf.test.TestCase):
...
@@ -394,6 +426,59 @@ class DeepMACMetaArchTest(tf.test.TestCase):
out
=
call_func
(
tf
.
zeros
((
2
,
4
)),
tf
.
zeros
((
2
,
32
,
32
,
8
)),
training
=
True
)
out
=
call_func
(
tf
.
zeros
((
2
,
4
)),
tf
.
zeros
((
2
,
32
,
32
,
8
)),
training
=
True
)
self
.
assertEqual
(
out
.
shape
,
(
2
,
32
,
32
))
self
.
assertEqual
(
out
.
shape
,
(
2
,
32
,
32
))
def test_box_consistency_loss(self):
  """Compares the box consistency loss with hand-built cross entropies."""
  gt_boxes = tf.constant([[0., 0., 0.49, 1.0]])
  crop_boxes = tf.constant([[0.0, 0.0, 1.0, 1.0]])

  logits = np.zeros((1, 32, 32)).astype(np.float32)
  logits[0, :24, :24] = 1.0

  # Both expected terms share the same logits and differ only in labels.
  expected_logits = [1.0] * 12 + [0.0] * 4
  yloss = tf.nn.sigmoid_cross_entropy_with_logits(
      labels=tf.constant([1.0] * 8 + [0.0] * 8), logits=expected_logits)
  xloss = tf.nn.sigmoid_cross_entropy_with_logits(
      labels=tf.constant([1.0] * 16), logits=expected_logits)
  want = [tf.reduce_mean(yloss + xloss).numpy()]

  got = self.model._compute_per_instance_box_consistency_loss(
      gt_boxes, crop_boxes, tf.constant(logits))
  self.assertAllClose(got, want)
def test_box_consistency_dice_loss(self):
  """Checks the box consistency loss when the model uses dice loss."""
  dice_model = build_meta_arch(use_dice_loss=True)
  gt_boxes = tf.constant([[0., 0., 0.49, 1.0]])
  crop_boxes = tf.constant([[0.0, 0.0, 1.0, 1.0]])

  # Very large positive and negative logits stand in for hard 1/0 masks.
  almost_inf = 1e10
  logits = np.full((1, 32, 32), -almost_inf, dtype=np.float32)
  logits[0, :24, :24] = almost_inf

  got = dice_model._compute_per_instance_box_consistency_loss(
      gt_boxes, crop_boxes, tf.constant(logits))

  yloss = 1 - 6.0 / 7
  xloss = 0.2
  self.assertAllClose(got, [yloss + xloss])
def test_box_consistency_dice_loss_full_res(self):
  """Checks the dice box consistency loss with full resolution masks."""
  full_res_model = build_meta_arch(
      use_dice_loss=True, predict_full_resolution_masks=True)
  gt_boxes = tf.constant([[0., 0., 1.0, 1.0]])
  # No crop boxes are supplied in full resolution mode.
  crop_boxes = None

  almost_inf = 1e10
  logits = np.full((1, 32, 32), -almost_inf, dtype=np.float32)
  logits[0, :16, :32] = almost_inf

  got = full_res_model._compute_per_instance_box_consistency_loss(
      gt_boxes, crop_boxes, tf.constant(logits))
  self.assertAlmostEqual(got[0].numpy(), 1 / 3)
@
unittest
.
skipIf
(
tf_version
.
is_tf1
(),
'Skipping TF2.X only test.'
)
@
unittest
.
skipIf
(
tf_version
.
is_tf1
(),
'Skipping TF2.X only test.'
)
class
FullyConnectedMaskHeadTest
(
tf
.
test
.
TestCase
):
class
FullyConnectedMaskHeadTest
(
tf
.
test
.
TestCase
):
...
...
research/object_detection/protos/center_net.proto
View file @
1cbb7a7d
...
@@ -446,6 +446,10 @@ message CenterNet {
...
@@ -446,6 +446,10 @@ message CenterNet {
// The mode for jittering box ROIs. See RandomJitterBoxes in
// The mode for jittering box ROIs. See RandomJitterBoxes in
// preprocessor.proto for more details
// preprocessor.proto for more details
optional
RandomJitterBoxes.JitterMode
jitter_mode
=
15
[
default
=
DEFAULT
];
optional
RandomJitterBoxes.JitterMode
jitter_mode
=
15
[
default
=
DEFAULT
];
// Weight for the box consistency loss as described in the BoxInst paper
// https://arxiv.org/abs/2012.02310
optional
float
box_consistency_loss_weight
=
16
[
default
=
0.0
];
}
}
optional
DeepMACMaskEstimation
deepmac_mask_estimation
=
14
;
optional
DeepMACMaskEstimation
deepmac_mask_estimation
=
14
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment