ModelZoo / ResNet50_tensorflow · Commits

Commit fbfa5038, authored Sep 14, 2022 by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 474374013
Parent: 2fe71495

Showing 1 changed file with 304 additions and 8 deletions:
official/projects/panoptic/modeling/layers/panoptic_segmentation_generator.py (+304, -8)
@@ -14,11 +14,33 @@
 """Contains definition for postprocessing layer to generate panoptic segmentations."""
-from typing import List, Optional
+from typing import Any, Dict, List, Optional, Tuple
 
 import tensorflow as tf
 
 from official.projects.panoptic.modeling.layers import paste_masks
+from official.vision.ops import spatial_transform_ops
+
+
+def _batch_count_ones(masks: tf.Tensor,
+                      dtype: tf.dtypes.DType = tf.int32) -> tf.Tensor:
+  """Counts the ones/trues for each mask in the batch.
+
+  Args:
+    masks: A tensor in shape (..., height, width) with arbitrary numbers of
+      batch dimensions.
+    dtype: DType of the resulting tensor. Default is tf.int32.
+
+  Returns:
+    A tensor which contains the count of non-zero elements for each mask in the
+    batch. The rank of the resulting tensor is equal to rank(masks) - 2.
+  """
+  masks_shape = masks.get_shape().as_list()
+  if len(masks_shape) < 2:
+    raise ValueError(
+        'Expected the input masks (..., height, width) has rank >= 2, was: %s'
+        % masks_shape)
+  return tf.reduce_sum(tf.cast(masks, dtype), axis=[-2, -1])
 
 
 class PanopticSegmentationGenerator(tf.keras.layers.Layer):
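For reference, a minimal sketch of what the new _batch_count_ones helper computes, using the same tf.reduce_sum reduction it wraps (illustrative only, not part of this commit):

import tensorflow as tf

# Two 2x2 boolean masks stacked along a leading batch dimension.
masks = tf.constant([[[True, False],
                      [True, True]],
                     [[False, False],
                      [False, True]]])
# Count the non-zero pixels of each mask over the last two axes.
areas = tf.reduce_sum(tf.cast(masks, tf.int32), axis=[-2, -1])
print(areas.numpy())  # [3 1]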
@@ -88,15 +110,18 @@ class PanopticSegmentationGenerator(tf.keras.layers.Layer):
         'void_instance_id': void_instance_id,
         'rescale_predictions': rescale_predictions
     }
-    super(PanopticSegmentationGenerator, self).__init__(**kwargs)
+    super().__init__(**kwargs)
 
-  def build(self, input_shape):
+  def build(self, input_shape: tf.TensorShape):
     grid_sampler = paste_masks.BilinearGridSampler(align_corners=False)
     self._paste_masks_fn = paste_masks.PasteMasks(
         output_size=self._output_size, grid_sampler=grid_sampler)
+    super().build(input_shape)
 
-  def _generate_panoptic_masks(self, boxes, scores, classes, detections_masks,
-                               segmentation_mask):
+  def _generate_panoptic_masks(
+      self, boxes: tf.Tensor, scores: tf.Tensor, classes: tf.Tensor,
+      detections_masks: tf.Tensor,
+      segmentation_mask: tf.Tensor) -> Dict[str, tf.Tensor]:
     """Generates panoptic masks for a single image.
 
     This function implements the following steps to merge instance and semantic
@@ -260,7 +285,9 @@ class PanopticSegmentationGenerator(tf.keras.layers.Layer):
           mask, 0, 0, self._output_size[0], self._output_size[1])
     return mask
 
-  def call(self, inputs: tf.Tensor, image_info: Optional[tf.Tensor] = None):
+  def call(self,
+           inputs: tf.Tensor,
+           image_info: Optional[tf.Tensor] = None) -> Dict[str, tf.Tensor]:
     detections = inputs
     batched_scores = detections['detection_scores']
@@ -313,9 +340,278 @@ class PanopticSegmentationGenerator(tf.keras.layers.Layer):
     return panoptic_masks
 
-  def get_config(self):
+  def get_config(self) -> Dict[str, Any]:
+    return self._config_dict
+
+  @classmethod
+  def from_config(
+      cls, config: Dict[str, Any]) -> 'PanopticSegmentationGenerator':
+    return cls(**config)
+
+
+class PanopticSegmentationGeneratorV2(tf.keras.layers.Layer):
+  """Panoptic segmentation generator layer V2."""
+
+  def __init__(self,
+               output_size: List[int],
+               max_num_detections: int,
+               stuff_classes_offset: int,
+               mask_binarize_threshold: float = 0.5,
+               score_threshold: float = 0.5,
+               things_overlap_threshold: float = 0.5,
+               stuff_area_threshold: float = 4096,
+               things_class_label: int = 1,
+               void_class_label: int = 0,
+               void_instance_id: int = -1,
+               rescale_predictions: bool = False,
+               **kwargs):
+    """Generates panoptic segmentation masks.
+
+    Args:
+      output_size: A `List` of integers that represent the height and width of
+        the output mask.
+      max_num_detections: `int` for maximum number of detections.
+      stuff_classes_offset: An `int` that is added to the output of the
+        semantic segmentation mask to make sure that the stuff class ids do not
+        overlap with the thing class ids of the MaskRCNN outputs.
+      mask_binarize_threshold: A `float` threshold used to binarize the pasted
+        instance masks.
+      score_threshold: A `float` representing the threshold for deciding when
+        to remove objects based on score.
+      things_overlap_threshold: A `float` representing a threshold for deciding
+        to ignore a thing if its overlap with higher-scoring things is above
+        the threshold.
+      stuff_area_threshold: A `float` representing a threshold for deciding to
+        ignore a stuff class if its area is below the threshold.
+      things_class_label: An `int` that represents a single merged category of
+        all thing classes in the semantic segmentation output.
+      void_class_label: An `int` that is used to represent empty or unlabelled
+        regions of the mask.
+      void_instance_id: An `int` that is used to denote regions that are not
+        assigned to any thing class. That is, void_instance_id is assigned to
+        both stuff regions and empty regions.
+      rescale_predictions: `bool`, whether to scale back prediction to original
+        image sizes. If True, image_info is used to rescale predictions.
+      **kwargs: additional keyword arguments.
+    """
+    self._output_size = output_size
+    self._max_num_detections = max_num_detections
+    self._stuff_classes_offset = stuff_classes_offset
+    self._mask_binarize_threshold = mask_binarize_threshold
+    self._score_threshold = score_threshold
+    self._things_overlap_threshold = things_overlap_threshold
+    self._stuff_area_threshold = stuff_area_threshold
+    self._things_class_label = things_class_label
+    self._void_class_label = void_class_label
+    self._void_instance_id = void_instance_id
+    self._rescale_predictions = rescale_predictions
+
+    self._config_dict = {
+        'output_size': output_size,
+        'max_num_detections': max_num_detections,
+        'stuff_classes_offset': stuff_classes_offset,
+        'mask_binarize_threshold': mask_binarize_threshold,
+        'score_threshold': score_threshold,
+        'things_class_label': things_class_label,
+        'void_class_label': void_class_label,
+        'void_instance_id': void_instance_id,
+        'rescale_predictions': rescale_predictions
+    }
+    super().__init__(**kwargs)
+
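As an aside, a minimal construction-and-config round trip (illustrative values; it assumes the tf-models package hosting the file path above is installed):

from official.projects.panoptic.modeling.layers import panoptic_segmentation_generator

# Hypothetical settings chosen only to show the required arguments.
generator = panoptic_segmentation_generator.PanopticSegmentationGeneratorV2(
    output_size=[640, 640],      # height and width of the output masks
    max_num_detections=100,      # matches the detector's max detections
    stuff_classes_offset=90,     # keeps stuff ids clear of thing ids
    rescale_predictions=False)
# get_config()/from_config() round-trip the stored constructor arguments.
clone = panoptic_segmentation_generator.PanopticSegmentationGeneratorV2.from_config(
    generator.get_config())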
+  def call(self,
+           inputs: tf.Tensor,
+           image_info: Optional[tf.Tensor] = None) -> Dict[str, tf.Tensor]:
+    """Generates panoptic segmentation masks."""
+    # (batch_size, num_rois, 4) in absolute coordinates.
+    detection_boxes = tf.cast(inputs['detection_boxes'], tf.float32)
+    # (batch_size, num_rois)
+    detection_classes = tf.cast(inputs['detection_classes'], tf.int32)
+    # (batch_size, num_rois)
+    detection_scores = tf.cast(inputs['detection_scores'], tf.float32)
+    # (batch_size, num_rois, mask_height, mask_width)
+    detections_masks = tf.cast(inputs['detection_masks'], tf.float32)
+    # (batch_size, height, width, num_semantic_classes)
+    segmentation_outputs = tf.cast(inputs['segmentation_outputs'], tf.float32)
+
+    if self._rescale_predictions:
+      # (batch_size, 2)
+      original_size = tf.cast(image_info[:, 0, :], tf.float32)
+      desired_size = tf.cast(image_info[:, 1, :], tf.float32)
+      image_scale = tf.cast(image_info[:, 2, :], tf.float32)
+      offset = tf.cast(image_info[:, 3, :], tf.float32)
+      rescale_size = tf.math.ceil(desired_size / image_scale)
+      # (batch_size, output_height, output_width, num_semantic_classes)
+      segmentation_outputs = (
+          spatial_transform_ops.bilinear_resize_with_crop_and_pad(
+              segmentation_outputs,
+              rescale_size,
+              crop_offset=offset,
+              crop_size=original_size,
+              output_size=self._output_size))
+      # (batch_size, 1, 4)
+      image_scale = tf.tile(image_scale, multiples=[1, 2])[:, tf.newaxis]
+      detection_boxes /= image_scale
+    else:
+      # (batch_size, output_height, output_width, num_semantic_classes)
+      segmentation_outputs = tf.image.resize(
+          segmentation_outputs, size=self._output_size, method='bilinear')
+
+    # (batch_size, output_height, output_width)
+    instance_mask, instance_category_mask = self._generate_instances(
+        detection_boxes, detection_classes, detection_scores, detections_masks)
+    # (batch_size, output_height, output_width)
+    stuff_category_mask = self._generate_stuffs(segmentation_outputs)
+
+    # (batch_size, output_height, output_width)
+    category_mask = tf.where(
+        (stuff_category_mask != self._void_class_label)
+        & (instance_category_mask == self._void_class_label),
+        stuff_category_mask + self._stuff_classes_offset,
+        instance_category_mask)
+
+    return {'instance_mask': instance_mask, 'category_mask': category_mask}
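A minimal sketch of feeding the layer, with random tensors shaped per the comments above (shapes and values are illustrative, not from the commit; generator is the instance sketched earlier):

import tensorflow as tf

batch_size, num_rois, mask_size = 1, 8, 28
# Well-formed boxes in absolute (ymin, xmin, ymax, xmax) coordinates.
y0 = tf.random.uniform([batch_size, num_rois, 1], 0., 320.)
x0 = tf.random.uniform([batch_size, num_rois, 1], 0., 320.)
inputs = {
    'detection_boxes': tf.concat([y0, x0, y0 + 160., x0 + 160.], axis=-1),
    'detection_classes': tf.random.uniform(
        [batch_size, num_rois], minval=1, maxval=91, dtype=tf.int32),
    'detection_scores': tf.random.uniform([batch_size, num_rois]),
    'detection_masks': tf.random.uniform(
        [batch_size, num_rois, mask_size, mask_size]),
    # Per-pixel scores over the semantic classes.
    'segmentation_outputs': tf.random.uniform([batch_size, 160, 160, 20]),
}
outputs = generator(inputs)
# Both masks come back at output_size resolution, e.g. (1, 640, 640).
print(outputs['instance_mask'].shape, outputs['category_mask'].shape)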
+
+  def _generate_instances(
+      self, detection_boxes: tf.Tensor, detection_classes: tf.Tensor,
+      detection_scores: tf.Tensor,
+      detections_masks: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
+    """Generates instance & category masks from instance segmentation outputs."""
+    batch_size = tf.shape(detections_masks)[0]
+    num_rois = tf.shape(detections_masks)[1]
+    mask_height = tf.shape(detections_masks)[2]
+    mask_width = tf.shape(detections_masks)[3]
+    output_height = self._output_size[0]
+    output_width = self._output_size[1]
+
+    # (batch_size, num_rois, mask_height, mask_width)
+    detections_masks = detections_masks * (
+        tf.cast((detection_scores > self._score_threshold)
+                & (detection_classes != self._void_class_label),
+                detections_masks.dtype)[:, :, tf.newaxis, tf.newaxis])
+
+    # Resizes and copies the detections_masks to the bounding boxes in the
+    # output canvas.
+    # (batch_size, num_rois, output_height, output_width)
+    pasted_detection_masks = tf.reshape(
+        spatial_transform_ops.bilinear_resize_to_bbox(
+            tf.reshape(detections_masks, [-1, mask_height, mask_width]),
+            tf.reshape(detection_boxes, [-1, 4]), self._output_size),
+        shape=[-1, num_rois, output_height, output_width])
+
+    # (batch_size, num_rois, output_height, output_width)
+    instance_binary_masks = (
+        pasted_detection_masks > self._mask_binarize_threshold)
+
+    # Sorts detection related tensors by scores.
+    # (batch_size, num_rois)
+    sorted_detection_indices = tf.argsort(
+        detection_scores, axis=1, direction='DESCENDING')
+    # (batch_size, num_rois)
+    sorted_detection_classes = tf.gather(
+        detection_classes, sorted_detection_indices, batch_dims=1)
+    # (batch_size, num_rois, output_height, output_width)
+    sorted_instance_binary_masks = tf.gather(
+        instance_binary_masks, sorted_detection_indices, batch_dims=1)
+
+    # (batch_size, num_rois)
+    instance_areas = _batch_count_ones(
+        sorted_instance_binary_masks, dtype=tf.float32)
+
+    init_loop_vars = (
+        0,  # i: the loop counter
+        tf.ones([batch_size, output_height, output_width], dtype=tf.int32)
+        * self._void_instance_id,  # combined_instance_mask
+        tf.ones([batch_size, output_height, output_width], dtype=tf.int32)
+        * self._void_class_label  # combined_category_mask
+    )
+
+    def _copy_instances_loop_body(
+        i: int, combined_instance_mask: tf.Tensor,
+        combined_category_mask: tf.Tensor) -> Tuple[int, tf.Tensor, tf.Tensor]:
+      """Iterates the sorted detections and copies the instances."""
+      # (batch_size, output_height, output_width)
+      instance_binary_mask = sorted_instance_binary_masks[:, i]
+
+      # Masks out the instances that have a big enough overlap with the other
+      # instances with higher scores.
+      # (batch_size, )
+      overlap_areas = _batch_count_ones(
+          (combined_instance_mask != self._void_instance_id)
+          & instance_binary_mask,
+          dtype=tf.float32)
+      # (batch_size, )
+      instance_overlap_threshold_mask = tf.math.divide_no_nan(
+          overlap_areas, instance_areas[:, i]) < self._things_overlap_threshold
+      # (batch_size, output_height, output_width)
+      instance_binary_mask &= (
+          instance_overlap_threshold_mask[:, tf.newaxis, tf.newaxis]
+          & (combined_instance_mask == self._void_instance_id))
+
+      # Updates combined_instance_mask.
+      # (batch_size, )
+      instance_id = tf.cast(
+          sorted_detection_indices[:, i] + 1,  # starting from 1
+          dtype=combined_instance_mask.dtype)
+      # (batch_size, output_height, output_width)
+      combined_instance_mask = tf.where(instance_binary_mask,
+                                        instance_id[:, tf.newaxis, tf.newaxis],
+                                        combined_instance_mask)
+
+      # Updates combined_category_mask.
+      # (batch_size, )
+      class_id = tf.cast(
+          sorted_detection_classes[:, i], dtype=combined_category_mask.dtype)
+      # (batch_size, output_height, output_width)
+      combined_category_mask = tf.where(instance_binary_mask,
+                                        class_id[:, tf.newaxis, tf.newaxis],
+                                        combined_category_mask)
+
+      # Returns the updated loop vars.
+      return (
+          i + 1,  # Increment the loop counter i
+          combined_instance_mask,
+          combined_category_mask)
+
+    # (batch_size, output_height, output_width)
+    _, instance_mask, category_mask = tf.while_loop(
+        cond=lambda i, *_: i < num_rois - 1,
+        body=_copy_instances_loop_body,
+        loop_vars=init_loop_vars,
+        parallel_iterations=32,
+        maximum_iterations=num_rois)
+    return instance_mask, category_mask
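To make the overlap test in the loop body concrete, a small numeric sketch of the ratio it thresholds (hypothetical values, mirroring the divide_no_nan comparison above):

import tensorflow as tf

# Suppose the current instance covers 100 pixels, 60 of which are already
# claimed by higher-scoring instances in combined_instance_mask.
overlap_areas = tf.constant([60.0])
instance_areas = tf.constant([100.0])
keep = tf.math.divide_no_nan(overlap_areas, instance_areas) < 0.5
print(keep.numpy())  # [False]: 60% overlap exceeds things_overlap_threshold=0.5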
+
+  def _generate_stuffs(self, segmentation_outputs: tf.Tensor) -> tf.Tensor:
+    """Generates category mask from semantic segmentation outputs."""
+    num_semantic_classes = tf.shape(segmentation_outputs)[3]
+
+    # (batch_size, output_height, output_width)
+    segmentation_masks = tf.argmax(
+        segmentation_outputs, axis=-1, output_type=tf.int32)
+    stuff_binary_masks = (
+        (segmentation_masks != self._things_class_label)
+        & (segmentation_masks != self._void_class_label))
+    # (batch_size, num_semantic_classes, output_height, output_width)
+    stuff_class_binary_masks = (
+        (tf.one_hot(segmentation_masks, num_semantic_classes, axis=1,
+                    dtype=tf.int32) == 1)
+        & tf.expand_dims(stuff_binary_masks, axis=1))
+
+    # Masks out the stuff class whose area is below the given threshold.
+    # (batch_size, num_semantic_classes)
+    stuff_class_areas = _batch_count_ones(
+        stuff_class_binary_masks, dtype=tf.float32)
+    # (batch_size, num_semantic_classes, output_height, output_width)
+    stuff_class_binary_masks &= tf.greater(
+        stuff_class_areas,
+        self._stuff_area_threshold)[:, :, tf.newaxis, tf.newaxis]
+    # (batch_size, output_height, output_width)
+    stuff_binary_masks = tf.reduce_any(stuff_class_binary_masks, axis=1)
+
+    # (batch_size, output_height, output_width)
+    return tf.where(stuff_binary_masks, segmentation_masks,
+                    tf.ones_like(segmentation_masks) * self._void_class_label)
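A tiny sketch of that stuff filtering: take the per-pixel argmax, drop thing/void pixels, then drop any stuff class whose area falls below the threshold (toy labels and a toy threshold, for illustration only):

import tensorflow as tf

things_class_label, void_class_label, area_threshold = 1, 0, 2.0
# A 1x3x3 "segmentation_masks" after argmax: class 0 is void, 1 is the merged
# things class, 2 and 3 are stuff classes.
seg = tf.constant([[[0, 1, 2],
                    [2, 2, 3],
                    [1, 1, 2]]])
stuff = (seg != things_class_label) & (seg != void_class_label)
per_class = ((tf.one_hot(seg, 4, axis=1, dtype=tf.int32) == 1)
             & stuff[:, tf.newaxis])
areas = tf.reduce_sum(tf.cast(per_class, tf.float32), axis=[-2, -1])
keep = per_class & (areas > area_threshold)[:, :, tf.newaxis, tf.newaxis]
kept_any = tf.reduce_any(keep, axis=1)
print(tf.where(kept_any, seg, tf.zeros_like(seg)).numpy())
# Class 2 (area 4) survives; class 3 (area 1) is mapped back to void.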
+
+  def get_config(self) -> Dict[str, Any]:
     return self._config_dict
 
   @classmethod
-  def from_config(cls, config):
+  def from_config(
+      cls, config: Dict[str, Any]) -> 'PanopticSegmentationGeneratorV2':
     return cls(**config)